• 如果您觉得本站非常有看点，那么赶紧使用 Ctrl+D 收藏吧！

Java AttributeDataset类的典型用法和代码示例

java 1次浏览

本文整理汇总了 Java 中 smile.data.AttributeDataset 的典型用法代码示例。如果您正苦于以下问题：Java AttributeDataset 类的具体用法？Java AttributeDataset 怎么用？Java AttributeDataset 使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。

AttributeDataset 类属于 smile.data 包，在下文中一共展示了 AttributeDataset 类的 40 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Java 代码示例。

示例1: main

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Demo entry point: parses the ALL/AML microarray test file (RES format) and displays its
 * numeric matrix as a heatmap in a Swing window.
 *
 * <p>Row labels come from the dataset's string row values (presumably gene identifiers —
 * confirm against RESParser) and column labels from the attribute names.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    try {
        RESParser parser = new RESParser();
        AttributeDataset data = parser.parse("RES", smile.data.parser.IOUtils.getTestDataFile("microarray/all_aml_test.res"));

        double[][] x = data.toArray(new double[data.size()][]);
        String[] genes = data.toArray(new String[data.size()]);
        String[] arrays = new String[data.attributes().length];
        for (int i = 0; i < arrays.length; i++) {
            arrays[i] = data.attributes()[i].getName();
        }

        JFrame frame = new JFrame("Heatmap");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        // A JFrame starts at 0x0; without an explicit size (or pack()) the window appears
        // collapsed and setLocationRelativeTo(null) centres an empty frame.
        frame.setSize(1000, 1000);
        frame.setLocationRelativeTo(null);
        frame.getContentPane().add(Heatmap.plot(genes, arrays, x, Palette.jet(256)));
        frame.setVisible(true);
    } catch (Exception ex) {
        // Demo boundary: report the failure and exit main.
        System.err.println(ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:22,
代码来源:HeatmapDemo.java

示例2: testTest_3args_1

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of test method, of class Validation.
 *
 * <p>Trains LDA on the USPS training file and checks that {@code Validation.test} reports
 * an accuracy of 0.8724 (within 1e-4) on the USPS test file. Column 0 of each file is parsed
 * as the nominal "class" response.
 */
@Test
public void testTest_3args_1() {
    System.out.println("test");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        LDA lda = new LDA(x, y);
        double accuracy = Validation.test(lda, testx, testy);
        System.out.println("accuracy = " + accuracy);
        assertEquals(0.8724, accuracy, 1E-4);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testTest_3args_1", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:26,
代码来源:ValidationTest.java

示例3: testTest_4args_1

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of test method, of class Validation.
 *
 * <p>Trains LDA on the USPS training file and checks that {@code Validation.test}, given a
 * single {@code Accuracy} measure, reports 0.8724 (within 1e-4) on the USPS test file.
 * Column 0 of each file is parsed as the nominal "class" response.
 */
@Test
public void testTest_4args_1() {
    System.out.println("test");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        LDA lda = new LDA(x, y);
        ClassificationMeasure[] measures = {new Accuracy()};
        double[] accuracy = Validation.test(lda, testx, testy, measures);
        System.out.println("accuracy = " + accuracy[0]);
        assertEquals(0.8724, accuracy[0], 1E-4);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testTest_4args_1", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:27,
代码来源:ValidationTest.java

示例4: testLoocv_3args_1

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of loocv method, of class Validation.
 *
 * <p>Runs leave-one-out cross-validation of an LDA trainer on the iris ARFF file (response
 * at column 4) and checks an accuracy of 0.8533 (within 1e-4).
 */
@Test
public void testLoocv_3args_1() {
    System.out.println("loocv");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        double accuracy = Validation.loocv(trainer, x, y);

        System.out.println("LOOCV accuracy = " + accuracy);
        assertEquals(0.8533, accuracy, 1E-4);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testLoocv_3args_1", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java

示例5: testLoocv_3args_2

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of loocv method, of class Validation.
 *
 * <p>Runs leave-one-out cross-validation of an RBF network regression trainer (20 centers,
 * Euclidean distance) on the cpu ARFF file (response at column 6) after standardizing the
 * features with {@code Math.standardize}, and prints the resulting RMSE. No value is asserted.
 */
@Test
public void testLoocv_3args_2() {
    System.out.println("loocv");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);

        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        double rmse = Validation.loocv(trainer, x, y);
        System.out.println("RMSE = " + rmse);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testLoocv_3args_2", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java

示例6: testCv_4args_1

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an LDA trainer on the iris ARFF file (response at
 * column 4) and prints the accuracy. No value is asserted.
 */
@Test
public void testCv_4args_1() {
    System.out.println("cv");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        double accuracy = Validation.cv(10, trainer, x, y);

        System.out.println("10-fold CV accuracy = " + accuracy);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testCv_4args_1", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:22,
代码来源:ValidationTest.java

示例7: testCv_4args_2

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an RBF network regression trainer (20 centers,
 * Euclidean distance) on the cpu ARFF file (response at column 6) after standardizing the
 * features with {@code Math.standardize}, and prints the RMSE. No value is asserted.
 */
@Test
public void testCv_4args_2() {
    System.out.println("cv");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);

        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        double rmse = Validation.cv(10, trainer, x, y);
        System.out.println("RMSE = " + rmse);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testCv_4args_2", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java

示例8: testCv_5args_1

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an LDA trainer on the iris ARFF file (response at
 * column 4) with an {@code Accuracy} measure and prints each measure's result. No value is
 * asserted.
 */
@Test
public void testCv_5args_1() {
    System.out.println("cv");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        ClassificationMeasure[] measures = {new Accuracy()};
        double[] results = Validation.cv(10, trainer, x, y, measures);
        for (int i = 0; i < measures.length; i++) {
            System.out.println(measures[i] + " = " + results[i]);
        }
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testCv_5args_1", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:24,
代码来源:ValidationTest.java

示例9: testCv_5args_2

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of cv method, of class Validation.
 *
 * <p>Runs 10-fold cross-validation of an RBF network regression trainer (20 centers,
 * Euclidean distance) on the standardized cpu ARFF data (response at column 6), computing
 * RMSE and mean absolute deviation, and prints both. No value is asserted.
 */
@Test
public void testCv_5args_2() {
    System.out.println("cv");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);

        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        RegressionMeasure[] measures = {new RMSE(), new MeanAbsoluteDeviation()};
        double[] results = Validation.cv(10, trainer, x, y, measures);
        System.out.println("RMSE = " + results[0]);
        System.out.println("MAD = " + results[1]);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testCv_5args_2", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:25,
代码来源:ValidationTest.java

示例10: testBootstrap_4args_1

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of bootstrap method, of class Validation.
 *
 * <p>Runs 100 bootstrap rounds of an LDA trainer on the iris ARFF file (response at column 4)
 * and prints the mean and standard deviation of the per-round accuracies. No value is
 * asserted.
 */
@Test
public void testBootstrap_4args_1() {
    System.out.println("bootstrap");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        ClassifierTrainer<double[]> trainer = new LDA.Trainer();
        double[] accuracy = Validation.bootstrap(100, trainer, x, y);

        System.out.println("100-fold bootstrap accuracy average = " + Math.mean(accuracy));
        System.out.println("100-fold bootstrap accuracy std.dev = " + Math.sd(accuracy));
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testBootstrap_4args_1", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:ValidationTest.java

示例11: testBootstrap_4args_2

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of bootstrap method, of class Validation.
 *
 * <p>Runs 100 bootstrap rounds of an RBF network regression trainer (20 centers, Euclidean
 * distance) on the standardized cpu ARFF data (response at column 6) and prints the mean and
 * standard deviation of the per-round RMSE. No value is asserted.
 */
@Test
public void testBootstrap_4args_2() {
    System.out.println("bootstrap");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);

        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        double[] rmse = Validation.bootstrap(100, trainer, x, y);
        System.out.println("100-fold bootstrap RMSE average = " + Math.mean(rmse));
        System.out.println("100-fold bootstrap RMSE std.dev = " + Math.sd(rmse));
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testBootstrap_4args_2", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:24,
代码来源:ValidationTest.java

示例12: testBootstrap_5args_2

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of bootstrap method, of class Validation.
 *
 * <p>Runs 100 bootstrap rounds of an RBF network regression trainer (20 centers, Euclidean
 * distance) on the standardized cpu ARFF data (response at column 6) with RMSE and mean
 * absolute deviation measures, and prints the mean and standard deviation of each. No value
 * is asserted.
 */
@Test
public void testBootstrap_5args_2() {
    System.out.println("bootstrap");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] y = data.toArray(new double[data.size()]);
        double[][] x = data.toArray(new double[data.size()][]);
        Math.standardize(x);

        RBFNetwork.Trainer<double[]> trainer = new RBFNetwork.Trainer<>(new EuclideanDistance());
        trainer.setNumCenters(20);
        RegressionMeasure[] measures = {new RMSE(), new MeanAbsoluteDeviation()};
        double[][] results = Validation.bootstrap(100, trainer, x, y, measures);
        System.out.println("100-fold bootstrap RMSE average = " + Math.mean(results[0]));
        System.out.println("100-fold bootstrap RMSE std.dev = " + Math.sd(results[0]));
        System.out.println("100-fold bootstrap AbsoluteDeviation average = " + Math.mean(results[1]));
        System.out.println("100-fold bootstrap AbsoluteDeviation std.dev = " + Math.sd(results[1]));
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testBootstrap_5args_2", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:27,
代码来源:ValidationTest.java

示例13: testAttributes

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of attributes method, of class DateFeature.
 *
 * <p>Builds a DateFeature over seven date components and checks that one attribute is
 * produced per component: MONTH (index 1) and DAY_OF_WEEK (index 3) must be nominal, and all
 * others numeric.
 */
@Test
public void testAttributes() {
    System.out.println("attributes");
    try {
        // The parsed dataset is not used by the assertions; the parse only checks that the
        // date.arff fixture loads.
        ArffParser parser = new ArffParser();
        parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/date.arff"));

        DateFeature.Type[] features = {DateFeature.Type.YEAR, DateFeature.Type.MONTH, DateFeature.Type.DAY_OF_MONTH, DateFeature.Type.DAY_OF_WEEK, DateFeature.Type.HOURS, DateFeature.Type.MINUTES, DateFeature.Type.SECONDS};
        DateFeature df = new DateFeature(features);
        Attribute[] attributes = df.attributes();
        assertEquals(features.length, attributes.length);
        for (int i = 0; i < attributes.length; i++) {
            System.out.println(attributes[i]);
            if (i == 1 || i == 3) {
                assertEquals(Attribute.Type.NOMINAL, attributes[i].getType());
            } else {
                assertEquals(Attribute.Type.NUMERIC, attributes[i].getType());
            }
        }
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the fixture cannot be loaded or
        // DateFeature fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testAttributes", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:27,
代码来源:DateFeatureTest.java

示例14: testRank

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of rank method, of class SumSquaresRatio.
 *
 * <p>Computes the sum-squares-ratio score of each of the four iris features (response at
 * column 4) and checks each against a reference value within 1e-7.
 */
@Test
public void testRank() {
    System.out.println("rank");
    try {
        ArffParser arffParser = new ArffParser();
        arffParser.setResponseIndex(4);
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        SumSquaresRatio ssr = new SumSquaresRatio();
        double[] ratio = ssr.rank(x, y);
        assertEquals(4, ratio.length);
        assertEquals(1.6226463, ratio[0], 1E-7);
        assertEquals(0.6444144, ratio[1], 1E-7);
        assertEquals(16.0412833, ratio[2], 1E-7);
        assertEquals(13.0520327, ratio[3], 1E-7);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // ranking fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testRank", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:25,
代码来源:SumSquaresRatioTest.java

示例15: CoverTreeSpeedTest

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Loads the USPS train/test feature matrices into {@code x} and {@code testx}, builds a
 * Euclidean CoverTree over the training points, and prints the time spent on each phase.
 *
 * @throws IllegalStateException if the USPS data cannot be loaded
 */
public CoverTreeSpeedTest() {
    // nanoTime is the monotonic clock intended for elapsed-time measurement;
    // currentTimeMillis is wall-clock time and can jump.
    long start = System.nanoTime();
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        x = train.toArray(new double[train.size()][]);
        testx = test.toArray(new double[test.size()][]);
    } catch (Exception ex) {
        // Continuing would leave x null and fail later with an unrelated NPE in CoverTree;
        // fail here with the real cause instead.
        throw new IllegalStateException("Failed to load USPS data", ex);
    }

    double time = (System.nanoTime() - start) / 1e9;
    System.out.format("Loading data: %.2fs%n", time);

    start = System.nanoTime();
    coverTree = new CoverTree<>(x, new EuclideanDistance());
    time = (System.nanoTime() - start) / 1e9;
    System.out.format("Building cover tree: %.2fs%n", time);
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:23,
代码来源:CoverTreeSpeedTest.java

示例16: LSHTest

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Loads the USPS train/test feature matrices into {@code x} and {@code testx}. It then builds
 * a brute-force {@code LinearSearch} baseline and an {@code LSH} index over the training
 * points; each training vector serves as both key and value.
 *
 * @throws IllegalStateException if the USPS data cannot be loaded
 */
public LSHTest() {
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        x = train.toArray(new double[train.size()][]);
        testx = test.toArray(new double[test.size()][]);
    } catch (Exception ex) {
        // Continuing would leave x null and fail later with an unrelated NPE;
        // fail here with the real cause instead.
        throw new IllegalStateException("Failed to load USPS data", ex);
    }

    naive = new LinearSearch<>(x, new EuclideanDistance());
    lsh = new LSH<>(x, x);
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:24,
代码来源:LSHTest.java

示例17: testUSPSNystrom

点赞 3

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class SpectralClustering.
 *
 * <p>Clusters the USPS training set with SpectralClustering using the constructor arguments
 * (10, 100, 8.0). The displayed name suggests the Nystrom approximation is used — confirm
 * against the SpectralClustering API. The test requires a Rand index above 0.8 and an
 * adjusted Rand index above 0.35 against the true labels.
 */
@Test
public void testUSPSNystrom() {
    System.out.println("USPS Nystrom approximation");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);

        SpectralClustering spectral = new SpectralClustering(x, 10, 100, 8.0);
        int[] labels = spectral.getClusterLabel();

        AdjustedRandIndex ari = new AdjustedRandIndex();
        RandIndex rand = new RandIndex();
        double r = rand.measure(y, labels);
        double r2 = ari.measure(y, labels);
        System.out.format("Training rand index = %.2f%%\tadjusted rand index = %.2f%%%n", 100.0 * r, 100.0 * r2);
        assertTrue(r > 0.8);
        assertTrue(r2 > 0.35);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or
        // clustering fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testUSPSNystrom", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:28,
代码来源:SpectralClusteringTest.java

示例18: testUSPS

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class RDA.
 *
 * <p>Trains RDA (constructor argument 0.7) on the USPS training file, counts
 * misclassifications on the USPS test file, and expects exactly 235 errors.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        RDA rda = new RDA(x, y, 0.7);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (rda.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(235, error);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in RDA testUSPS", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:RDATest.java

示例19: testWeather

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class AdaBoost.
 *
 * <p>Runs leave-one-out evaluation of AdaBoost (200 trees, constructor argument 4) on the
 * nominal weather ARFF data (response at column 4). For each left-out row it retrains on the
 * remaining rows, and it expects exactly 3 misclassifications overall.
 */
@Test
public void testWeather() {
    System.out.println("Weather");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset weather = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/weather.nominal.arff"));
        double[][] x = weather.toArray(new double[weather.size()][]);
        int[] y = weather.toArray(new int[weather.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);

            AdaBoost boost = new AdaBoost(weather.attributes(), trainx, trainy, 200, 4);
            if (y[loocv.test[i]] != boost.predict(x[loocv.test[i]])) {
                error++;
            }
        }

        System.out.println("AdaBoost error = " + error);
        assertEquals(3, error);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in testWeather", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:32,
代码来源:AdaBoostTest.java

示例20: testSegment

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class SVM.
 *
 * <p>Trains a one-vs-all multiclass SVM (Gaussian kernel 8.0, constructor argument 5.0) on the
 * segment-challenge ARFF file (response at column 19) and requires fewer than 70
 * misclassifications on the segment-test file.
 */
@Test
public void testSegment() {
    System.out.println("Segment");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(19);
    try {
        AttributeDataset train = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-challenge.arff"));
        AttributeDataset test = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/segment-test.arff"));

        System.out.println(train.size() + " " + test.size());
        // Presize the target arrays like every other caller of AttributeDataset.toArray does,
        // rather than relying on toArray to reallocate an undersized array.
        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        SVM<double[]> svm = new SVM<>(new GaussianKernel(8.0), 5.0, Math.max(y) + 1, SVM.Multiclass.ONE_VS_ALL);
        svm.learn(x, y);
        svm.finish();

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (svm.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.format("Segment error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error < 70);
    } catch (Exception ex) {
        // printStackTrace alone let the test pass on failure; report it as a test failure
        // with the original cause.
        throw new AssertionError("Unexpected exception in testSegment", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:36,
代码来源:SVMTest.java

示例21: test

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Runs 10-fold cross-validation of a RegressionTree (constructor argument 20) on an ARFF
 * dataset and prints the cross-validated RMSE and mean absolute deviation.
 *
 * @param dataset  display name printed before the run
 * @param url      test-data path passed to {@code IOUtils.getTestDataFile}
 * @param response column index of the numeric response attribute
 * @throws AssertionError if the data cannot be loaded or training/prediction fails
 */
public void test(String dataset, String url, int response) {
    System.out.println(dataset);
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(response);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        double ad = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            RegressionTree tree = new RegressionTree(data.attributes(), trainx, trainy, 20);

            // Accumulate squared and absolute residuals over every held-out row; each row is
            // tested exactly once across the k folds, so dividing by n gives the CV averages.
            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - tree.predict(testx[j]);
                rss += r * r;
                ad += Math.abs(r);
            }
        }

        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss/n), ad/n);
    } catch (Exception ex) {
        // Swallowing here would make callers' tests pass when the data cannot be loaded or
        // the model fails; surface it as a failure with the original cause.
        throw new AssertionError("Unexpected exception while testing " + dataset, ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:36,
代码来源:RegressionTreeTest.java

示例22: testLearn

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class QDA.
 *
 * <p>Runs leave-one-out evaluation of QDA on the iris ARFF data (response at column 4) and
 * expects exactly 4 misclassifications. The {@code posteriori} buffer handed to
 * {@code predict} is reused across iterations and its contents are not checked.
 */
@Test
public void testLearn() {
    System.out.println("learn");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        double[] posteriori = new double[3];
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            QDA qda = new QDA(trainx, trainy);

            if (y[loocv.test[i]] != qda.predict(x[loocv.test[i]], posteriori)) {
                error++;
            }
        }

        System.out.println("QDA error = " + error);
        assertEquals(4, error);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in QDA testLearn", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:QDATest.java

示例23: testUSPS

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class LogisticRegression.
 *
 * <p>Trains LogisticRegression with constructor arguments (0.3, 1E-3, 1000) on the USPS
 * training file and expects exactly 188 misclassifications on the USPS test file.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        LogisticRegression logit = new LogisticRegression(x, y, 0.3, 1E-3, 1000);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (logit.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(188, error);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in LogisticRegression testUSPS", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:LogisticRegressionTest.java

示例24: testIris

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class LogisticRegression.
 *
 * <p>Runs leave-one-out evaluation of LogisticRegression (default constructor arguments) on
 * the iris ARFF data (response at column 4) and expects exactly 3 misclassifications.
 */
@Test
public void testIris() {
    System.out.println("Iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            LogisticRegression logit = new LogisticRegression(trainx, trainy);

            if (y[loocv.test[i]] != logit.predict(x[loocv.test[i]])) {
                error++;
            }
        }

        System.out.println("Logistic Regression error = " + error);
        assertEquals(3, error);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in LogisticRegression testIris", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:32,
代码来源:LogisticRegressionTest.java

示例25: getClassifier

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
@Override
public Classifier<double[]> getClassifier() throws IOException, ParseException {
  // Train on first use only; subsequent calls return the cached model.
  if (classifier == null) {
    DelimitedTextParser csv = new DelimitedTextParser();
    csv.setDelimiter(",");
    // Column 0 holds the shading label, one of "0", "1" or "2".
    NominalAttribute shading = new NominalAttribute("shading", new String[] { "0", "1", "2" });
    csv.setResponseIndex(shading, 0);
    AttributeDataset trainingSet = csv.parse("data/train-out-" + getFileSuffix());
    double[][] features = trainingSet.toArray(new double[trainingSet.size()][]);
    int[] labels = trainingSet.toArray(new int[trainingSet.size()]);
    // k-nearest-neighbour classifier with k = 5.
    classifier = KNN.learn(features, labels, 5);
  }
  return classifier;
}
 

开发者ID:tomwhite,
项目名称:set-game,
代码行数:15,
代码来源:FindCardShadingFeatures.java

示例26: testUSPS

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class RandomForest.
 *
 * <p>Trains a 200-tree RandomForest on the USPS training file and prints its OOB error. It
 * requires at most 225 misclassifications on the USPS test file.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        RandomForest forest = new RandomForest(x, y, 200);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (forest.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.println("USPS error = " + error);
        System.out.format("USPS OOB error rate = %.2f%%%n", 100.0 * forest.error());
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error <= 225);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in RandomForest testUSPS", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:RandomForestTest.java

示例27: testUSPS

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class LDA.
 *
 * <p>Trains LDA on the USPS training file and expects exactly 256 misclassifications on the
 * USPS test file.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        LDA lda = new LDA(x, y);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (lda.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(256, error);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in LDA testUSPS", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:LDATest.java

示例28: testIris

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class AdaBoost.
 *
 * <p>Recodes the iris labels (response at column 4) into a binary problem: class 0 versus all
 * other classes. It then runs leave-one-out evaluation of a 200-tree AdaBoost model and expects
 * no misclassifications.
 */
@Test
public void testIris() {
    System.out.println("Iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);
        // Collapse every non-zero class into label 1 to obtain a two-class problem.
        for (int i = 0; i < y.length; i++) {
            if (y[i] != 0) {
                y[i] = 1;
            }
        }

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);

            AdaBoost boost = new AdaBoost(iris.attributes(), trainx, trainy, 200);
            if (y[loocv.test[i]] != boost.predict(x[loocv.test[i]])) {
                error++;
            }
        }

        System.out.println("AdaBoost error = " + error);
        assertEquals(0, error);
    } catch (Exception ex) {
        // Swallowing here would let the test pass when the data cannot be loaded or the
        // model fails; surface it as a test failure with the original cause.
        throw new AssertionError("Unexpected exception in AdaBoost testIris", ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:AdaBoostTest.java

示例29: testUSPS

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class GradientTreeBoost.
 * <p>
 * Trains a 100-tree gradient boosting model on USPS and reports the test error
 * rate, the test accuracy for each prefix of the ensemble, and the variable
 * importances in descending order. Any exception while loading the data is
 * rethrown as an {@link AssertionError} so the test cannot silently pass.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        GradientTreeBoost boost = new GradientTreeBoost(train.attributes(), x, y, 100);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (boost.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.format("Gradient Tree Boost error rate = %.2f%%%n", 100.0 * error / testx.length);

        // accuracy[k] is reported as the accuracy of the first k+1 trees.
        double[] accuracy = boost.test(testx, testy);
        for (int i = 1; i <= accuracy.length; i++) {
            System.out.format("%d trees accuracy = %.2f%%%n", i, 100.0 * accuracy[i-1]);
        }

        // QuickSort.sort reorders its argument in place and returns the permutation.
        // Sort a copy so the model's importance vector is not reordered in case
        // importance() hands back the model's internal array.
        double[] importance = boost.importance().clone();
        int[] index = QuickSort.sort(importance);
        for (int i = importance.length; i-- > 0; ) {
            System.out.format("%s importance is %.4f%n", train.attributes()[index[i]], importance[i]);
        }
    } catch (Exception ex) {
        // Do not swallow load/parse failures: that would make the test pass vacuously.
        throw new AssertionError("USPS GradientTreeBoost test failed: " + ex, ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:43,
代码来源:GradientTreeBoostTest.java

示例30: testIris2

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of predict method, of class GradientTreeBoost.
 * <p>
 * Builds a binary Iris problem (class 2 versus the rest) and runs leave-one-out
 * cross-validation with a 100-tree gradient boosting model per fold, printing
 * the number of errors. Any exception while loading the data is rethrown as an
 * {@link AssertionError} so the test cannot silently pass.
 */
@Test
public void testIris2() {
    System.out.println("Iris binary");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        // Relabel: class 2 becomes 1, every other class becomes 0.
        for (int i = 0; i < y.length; i++) {
            y[i] = (y[i] == 2) ? 1 : 0;
        }

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);
            GradientTreeBoost boost = new GradientTreeBoost(iris.attributes(), trainx, trainy, 100);

            if (y[loocv.test[i]] != boost.predict(x[loocv.test[i]])) {
                error++;
            }
        }

        System.out.println("Gradient Tree Boost error = " + error);
        // NOTE(review): the expected error count is disabled in the original test;
        // re-enable once the expected value is confirmed.
        //assertEquals(6, error);
    } catch (Exception ex) {
        // Do not swallow load/parse failures: that would make the test pass vacuously.
        throw new AssertionError("Iris binary GradientTreeBoost test failed: " + ex, ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:40,
代码来源:GradientTreeBoostTest.java

示例31: testParse

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class DelimitedTextParser.
 * <p>
 * Parses the USPS training file with column 0 as a nominal response and checks
 * the attribute types, the dataset dimensions and a few known values in the
 * first and last rows. Exceptions propagate so that a missing or unparsable
 * data file fails the test rather than being printed and ignored.
 */
@Test
public void testParse() throws Exception {
    System.out.println("parse");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    AttributeDataset usps = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
    double[][] x = usps.toArray(new double[usps.size()][]);
    int[] y = usps.toArray(new int[usps.size()]);

    assertEquals(Attribute.Type.NOMINAL, usps.responseAttribute().getType());
    for (Attribute attribute : usps.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
    }

    assertEquals(7291, usps.size());
    assertEquals(256, usps.attributes().length);

    // First rows.
    assertEquals("6", usps.responseAttribute().toString(y[0]));
    assertEquals("5", usps.responseAttribute().toString(y[1]));
    assertEquals("4", usps.responseAttribute().toString(y[2]));
    assertEquals(-1.0000, x[0][6], 1E-7);
    assertEquals(-0.6310, x[0][7], 1E-7);
    assertEquals(0.8620, x[0][8], 1E-7);

    // Last row.
    assertEquals("1", usps.responseAttribute().toString(y[7290]));
    assertEquals(-1.0000, x[7290][4], 1E-7);
    assertEquals(-0.1080, x[7290][5], 1E-7);
    assertEquals(1.0000, x[7290][6], 1E-7);
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:37,
代码来源:DelimitedTextParserTest.java

示例32: test

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Runs 10-fold cross-validation of a GradientTreeBoost regression model on an
 * ARFF dataset and prints the root mean squared error and the mean absolute
 * deviation over all held-out samples.
 *
 * @param loss     loss function passed to the GradientTreeBoost constructor
 * @param dataset  label printed alongside the loss before the run
 * @param url      data file name, resolved via {@code IOUtils.getTestDataFile}
 * @param response index of the response column in the ARFF file
 * @throws AssertionError if the data cannot be loaded or training fails
 */
public void test(GradientTreeBoost.Loss loss, String dataset, String url, int response) {
    System.out.println(dataset + "\t" + loss);
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(response);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;  // residual sum of squares over all folds
        double ad = 0.0;   // sum of absolute residuals over all folds
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            GradientTreeBoost boost = new GradientTreeBoost(data.attributes(), trainx, trainy, loss, 100, 6, 0.05, 0.7);

            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - boost.predict(testx[j]);
                ad += Math.abs(r);
                rss += r * r;
            }
        }

        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss/n), ad/n);
    } catch (Exception ex) {
        // Printing and continuing would hide a broken dataset from every caller;
        // fail explicitly and keep the original cause.
        throw new AssertionError("GradientTreeBoost CV on " + dataset + " failed: " + ex, ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:36,
代码来源:GradientTreeBoostTest.java

示例33: testParse

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class GCTParser.
 * <p>
 * Parses the ALL/AML GCT microarray file and checks that all attributes are
 * numeric, the dataset dimensions, and known row identifiers and values in the
 * first and last rows. Exceptions propagate so that a missing or unparsable
 * data file fails the test rather than being printed and ignored.
 */
@Test
public void testParse() throws Exception {
    System.out.println("parse");
    GCTParser parser = new GCTParser();
    AttributeDataset data = parser.parse("GCT", smile.data.parser.IOUtils.getTestDataFile("microarray/allaml.dataset.gct"));

    double[][] x = data.toArray(new double[data.size()][]);
    String[] id = data.toArray(new String[data.size()]);

    for (Attribute attribute : data.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
        System.out.println(attribute.getName());
    }

    assertEquals(12564, data.size());
    assertEquals(48, data.attributes().length);

    // First row.
    assertEquals("AFFX-MurIL2_at", id[0]);
    assertEquals(-161.8, x[0][0], 1E-7);
    assertEquals(-231.0, x[0][1], 1E-7);
    assertEquals(-279.0, x[0][2], 1E-7);

    // Last row.
    assertEquals("128_at", id[12563]);
    assertEquals(95.0, x[12563][45], 1E-7);
    assertEquals(108.0, x[12563][46], 1E-7);
    assertEquals(346.0, x[12563][47], 1E-7);
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:GCTParserTest.java

示例34: testUSPS

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class AdaBoost.
 * <p>
 * Reduces USPS to a binary problem (label 0 versus all other labels), trains
 * AdaBoost with 100 trees and max-nodes argument 6, and checks that the test set
 * has at most 25 errors. Any exception while loading the data is rethrown as an
 * {@link AssertionError} so the test cannot silently pass.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        // Collapse every non-zero label into 1 in both splits.
        for (int i = 0; i < y.length; i++) {
            if (y[i] != 0) {
                y[i] = 1;
            }
        }
        for (int i = 0; i < testy.length; i++) {
            if (testy[i] != 0) {
                testy[i] = 1;
            }
        }

        AdaBoost boost = new AdaBoost(x, y, 100, 6);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (boost.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.println("AdaBoost error = " + error);
        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertTrue(error <= 25);
    } catch (Exception ex) {
        // Do not swallow load/parse failures: that would make the test pass vacuously.
        throw new AssertionError("USPS AdaBoost test failed: " + ex, ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:41,
代码来源:AdaBoostTest.java

示例35: testParse

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class PCLParser.
 * <p>
 * Parses the Dunham2002 PCL microarray file and checks that all attributes are
 * numeric, the dataset dimensions, and known row identifiers and values in the
 * first and last rows. Exceptions propagate so that a missing or unparsable
 * data file fails the test rather than being printed and ignored.
 */
@Test
public void testParse() throws Exception {
    System.out.println("parse");
    PCLParser parser = new PCLParser();
    AttributeDataset data = parser.parse("PCL", smile.data.parser.IOUtils.getTestDataFile("microarray/Dunham2002.pcl"));

    double[][] x = data.toArray(new double[data.size()][]);
    String[] id = data.toArray(new String[data.size()]);

    for (Attribute attribute : data.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
        System.out.println(attribute.getName());
    }

    assertEquals(6694, data.size());
    assertEquals(16, data.attributes().length);

    // First row.
    assertEquals("YKR005C", id[0]);
    assertEquals(-0.43, x[0][0], 1E-7);
    assertEquals(-0.47, x[0][1], 1E-7);
    assertEquals(-0.39, x[0][2], 1E-7);

    // Last row.
    assertEquals("YKR004C", id[6693]);
    assertEquals(0.03, x[6693][13], 1E-7);
    assertEquals(-0.53, x[6693][14], 1E-7);
    assertEquals(0.3, x[6693][15], 1E-7);
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:35,
代码来源:PCLParserTest.java

示例36: testParseWeather

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class ArffParser, on the nominal weather dataset.
 * <p>
 * Uses column 4 as the response and checks that the response and all four
 * attributes are nominal, the number of rows, and the decoded values of the
 * first and last rows. Exceptions propagate so that a missing or unparsable
 * data file fails the test rather than being printed and ignored.
 */
@Test
public void testParseWeather() throws Exception {
    System.out.println("weather");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset weather = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/weather.nominal.arff"));
    double[][] x = weather.toArray(new double[weather.size()][]);
    int[] y = weather.toArray(new int[weather.size()]);

    assertEquals(Attribute.Type.NOMINAL, weather.responseAttribute().getType());
    for (Attribute attribute : weather.attributes()) {
        assertEquals(Attribute.Type.NOMINAL, attribute.getType());
    }

    assertEquals(14, weather.size());
    assertEquals(4, weather.attributes().length);

    // First rows; nominal values are stored as codes and decoded via toString.
    assertEquals("no", weather.responseAttribute().toString(y[0]));
    assertEquals("no", weather.responseAttribute().toString(y[1]));
    assertEquals("yes", weather.responseAttribute().toString(y[2]));
    assertEquals("sunny", weather.attributes()[0].toString(x[0][0]));
    assertEquals("hot", weather.attributes()[1].toString(x[0][1]));
    assertEquals("high", weather.attributes()[2].toString(x[0][2]));
    assertEquals("FALSE", weather.attributes()[3].toString(x[0][3]));

    // Last row.
    assertEquals("no", weather.responseAttribute().toString(y[13]));
    assertEquals("rainy", weather.attributes()[0].toString(x[13][0]));
    assertEquals("mild", weather.attributes()[1].toString(x[13][1]));
    assertEquals("high", weather.attributes()[2].toString(x[13][2]));
    assertEquals("TRUE", weather.attributes()[3].toString(x[13][3]));
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:38,
代码来源:ArffParserTest.java

示例37: testParseIris

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class ArffParser, on the Iris dataset.
 * <p>
 * Uses column 4 as the response and checks that the response is nominal, the
 * four attributes are numeric, the number of rows, and the values of the first
 * and last rows. Exceptions propagate so that a missing or unparsable data file
 * fails the test rather than being printed and ignored.
 */
@Test
public void testParseIris() throws Exception {
    System.out.println("iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    assertEquals(Attribute.Type.NOMINAL, iris.responseAttribute().getType());
    for (Attribute attribute : iris.attributes()) {
        assertEquals(Attribute.Type.NUMERIC, attribute.getType());
    }

    assertEquals(150, iris.size());
    assertEquals(4, iris.attributes().length);

    // First rows.
    assertEquals("Iris-setosa", iris.responseAttribute().toString(y[0]));
    assertEquals("Iris-setosa", iris.responseAttribute().toString(y[1]));
    assertEquals("Iris-setosa", iris.responseAttribute().toString(y[2]));
    assertEquals(5.1, x[0][0], 1E-7);
    assertEquals(3.5, x[0][1], 1E-7);
    assertEquals(1.4, x[0][2], 1E-7);
    assertEquals(0.2, x[0][3], 1E-7);

    // Last row.
    assertEquals("Iris-virginica", iris.responseAttribute().toString(y[149]));
    assertEquals(5.9, x[149][0], 1E-7);
    assertEquals(3.0, x[149][1], 1E-7);
    assertEquals(5.1, x[149][2], 1E-7);
    assertEquals(1.8, x[149][3], 1E-7);
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:38,
代码来源:ArffParserTest.java

示例38: testParseSparse

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of parse method, of class ArffParser, on a sparse ARFF file.
 * <p>
 * Checks that the two sparse rows expand to dense vectors of five attributes,
 * with omitted entries read as 0.0. Exceptions propagate so that a missing or
 * unparsable data file fails the test rather than being printed and ignored.
 */
@Test
public void testParseSparse() throws Exception {
    System.out.println("sparse");
    ArffParser arffParser = new ArffParser();
    AttributeDataset sparse = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/sparse.arff"));
    double[][] x = sparse.toArray(new double[sparse.size()][]);

    assertEquals(2, sparse.size());
    assertEquals(5, sparse.attributes().length);

    assertEquals(0.0, x[0][0], 1E-7);
    assertEquals(2.0, x[0][1], 1E-7);
    assertEquals(0.0, x[0][2], 1E-7);
    assertEquals(3.0, x[0][3], 1E-7);
    assertEquals(0.0, x[0][4], 1E-7);

    assertEquals(0.0, x[1][0], 1E-7);
    assertEquals(0.0, x[1][1], 1E-7);
    assertEquals(1.0, x[1][2], 1E-7);
    assertEquals(0.0, x[1][3], 1E-7);
    assertEquals(1.0, x[1][4], 1E-7);
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:30,
代码来源:ArffParserTest.java

示例39: testKPCAThreshold

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class KPCA, with a threshold argument.
 * <p>
 * Fits kernel PCA to the Iris features with a Gaussian kernel and threshold
 * argument 1E-4. It then checks the variances against the reference values in
 * {@code latent}, and checks that projections of the training data match
 * {@code getCoordinates()}. Any exception while loading the data is rethrown as
 * an {@link AssertionError} so the test cannot silently pass.
 */
@Test
public void testKPCAThreshold() {
    System.out.println("learn threshold");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);
    try {
        AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));

        double[][] x = iris.toArray(new double[iris.size()][]);
        // Explicit type argument instead of the raw type avoids an unchecked conversion.
        KPCA<double[]> kpca = new KPCA<double[]>(x, new GaussianKernel(Math.sqrt(2.5)), 1E-4);
        // NOTE(review): `latent` is the reference variance vector declared elsewhere in this test class.
        assertTrue(Math.equals(latent, kpca.getVariances(), 1E-3));
        double[][] points = kpca.project(x);
        // Replace row 0 with the single-sample projection so the comparison below also
        // checks that project(double[]) agrees with the batch projection.
        points[0] = kpca.project(x[0]);
        assertTrue(Math.equals(points, kpca.getCoordinates(), 1E-7));
        /*
         * Disabled check (relies on reference `scores` declared elsewhere): align the
         * sign of each component with the reference scores, then compare.
         *
        for (int j = 0; j < points[0].length; j++) {
            double sign = Math.signum(points[0][j] / scores[0][j]);
            for (int i = 0; i < points.length; i++) {
                points[i][j] *= sign;
            }
        }

        assertTrue(Math.equals(scores, points, 1E-1));
         */
    } catch (Exception ex) {
        // Do not swallow load/parse failures: that would make the test pass vacuously.
        throw new AssertionError("KPCA threshold test failed: " + ex, ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:32,
代码来源:KPCATest.java

示例40: testUSPS

点赞 2

import smile.data.AttributeDataset; //导入依赖的package包/类
/**
 * Test of learn method, of class DecisionTree.
 * <p>
 * Trains a decision tree on USPS with node-limit argument 350 and the entropy
 * split rule, then checks the number of test-set errors. Any exception while
 * loading the data is rethrown as an {@link AssertionError} so the test cannot
 * silently pass.
 */
@Test
public void testUSPS() {
    System.out.println("USPS");
    DelimitedTextParser parser = new DelimitedTextParser();
    parser.setResponseIndex(new NominalAttribute("class"), 0);
    try {
        AttributeDataset train = parser.parse("USPS Train", smile.data.parser.IOUtils.getTestDataFile("usps/zip.train"));
        AttributeDataset test = parser.parse("USPS Test", smile.data.parser.IOUtils.getTestDataFile("usps/zip.test"));

        double[][] x = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);
        double[][] testx = test.toArray(new double[test.size()][]);
        int[] testy = test.toArray(new int[test.size()]);

        DecisionTree tree = new DecisionTree(x, y, 350, DecisionTree.SplitRule.ENTROPY);

        int error = 0;
        for (int i = 0; i < testx.length; i++) {
            if (tree.predict(testx[i]) != testy[i]) {
                error++;
            }
        }

        System.out.format("USPS error rate = %.2f%%%n", 100.0 * error / testx.length);
        assertEquals(328, error);
    } catch (Exception ex) {
        // Do not swallow load/parse failures: that would make the test pass vacuously.
        throw new AssertionError("USPS DecisionTree test failed: " + ex, ex);
    }
}
 

开发者ID:takun2s,
项目名称:smile_1.5.0_java7,
代码行数:33,
代码来源:DecisionTreeTest.java


版权声明:本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系管理员进行删除。
喜欢 (0)