From 5b77ebb57ba6edf1f8f4f8e83abef4f9f1ae6175 Mon Sep 17 00:00:00 2001 From: Ludovic Henry Date: Tue, 27 Apr 2021 14:00:59 -0500 Subject: [PATCH] [SPARK-35150][ML] Accelerate fallback BLAS with dev.ludovic.netlib ### What changes were proposed in this pull request? Following https://github.com/apache/spark/pull/30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package. The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focuses on accelerating the linear algebra routines in use in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation. Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, and takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available. A table summarising which version gets loaded in which case: ``` | | BLAS.nativeBLAS | BLAS.javaBLAS | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | | with -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a | 1. dev.ludovic.netlib.blas.VectorizedBLAS | | | wrapper for com.github.fommil:all | (JDK16+, relies on the Vector API, requires | | | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | `--add-modules=jdk.incubator.vector` on JDK16) | | | relies on the Foreign Linker API, requires | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) | | | `--add-modules=jdk.incubator.foreign | 3. dev.ludovic.netlib.blas.JavaBLAS | | | -Dforeign.restricted=warn`) | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a | | | 3. 
fails to load, falls back to BLAS.javaBLAS in | wrapper for com.github.fommil:core | | | org.apache.spark.ml.linalg.BLAS | | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | | without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | 1. dev.ludovic.netlib.blas.VectorizedBLAS | | | relies on the Foreign Linker API, requires | (JDK16+, relies on the Vector API, requires | | | `--add-modules=jdk.incubator.foreign | `--add-modules=jdk.incubator.vector` on JDK16) | | | -Dforeign.restricted=warn`) | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) | | | 2. fails to load, falls back to BLAS.javaBLAS in | 3. dev.ludovic.netlib.blas.JavaBLAS | | | org.apache.spark.ml.linalg.BLAS | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a | | | | wrapper for com.github.fommil:core | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | ``` ### Why are the changes needed? Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available. ### Does this PR introduce _any_ user-facing change? No, all changes are transparent to the user. ### How was this patch tested? The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite. 
[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`: #### JDK8: ``` [info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.Java8BLAS [info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 223 232 8 448.0 2.2 1.0X [info] java 221 228 7 453.0 2.2 1.0X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 122 128 4 821.2 1.2 1.0X [info] java 122 128 4 822.3 1.2 1.0X [info] [info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 109 112 2 921.4 1.1 1.0X [info] java 70 74 3 1423.5 0.7 1.5X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 98 2 1046.1 1.0 1.0X [info] java 47 49 2 2121.7 0.5 2.0X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 184 195 8 544.3 1.8 1.0X [info] java 185 196 7 539.5 1.9 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] 
------------------------------------------------------------------------------------------------------------------------ [info] f2j 99 104 4 1011.9 1.0 1.0X [info] java 99 104 4 1010.4 1.0 1.0X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 947.2 1.1 1.0X [info] java 0 0 0 1584.8 0.6 1.7X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 867.4 1.2 1.0X [info] java 1 1 0 865.0 1.2 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 485.9 2.1 1.0X [info] java 1 1 0 486.8 2.1 1.0X [info] [info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1843.0 0.5 1.0X [info] java 0 0 0 2690.6 0.4 1.5X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1214.7 0.8 1.0X [info] java 0 0 0 2536.8 0.4 2.1X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1895.9 0.5 1.0X [info] java 0 0 0 2961.1 0.3 1.6X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1223.4 0.8 1.0X [info] java 0 0 0 3091.4 0.3 2.5X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 560 575 20 1787.1 0.6 1.0X [info] java 226 232 5 4432.4 0.2 2.5X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 570 586 23 1755.2 0.6 1.0X [info] java 227 232 4 4410.1 0.2 2.5X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 863 879 17 1158.4 0.9 1.0X [info] java 227 231 3 4407.9 0.2 3.8X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1282 1305 23 780.0 1.3 1.0X [info] java 227 232 4 4413.4 0.2 5.7X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 538 548 8 1858.6 0.5 1.0X [info] java 221 226 3 4521.1 0.2 2.4X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 549 558 10 1819.9 0.5 1.0X [info] java 222 229 7 4503.5 0.2 2.5X [info] 
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 838 852 12 1193.0 0.8 1.0X [info] java 222 229 5 4500.5 0.2 3.8X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 905 919 18 1104.8 0.9 1.0X [info] java 221 228 5 4521.3 0.2 4.1X ``` #### JDK11: ``` [info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.Java11BLAS [info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 195 204 10 512.7 2.0 1.0X [info] java 195 202 7 512.4 2.0 1.0X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 108 113 4 923.3 1.1 1.0X [info] java 102 107 4 984.4 1.0 1.1X [info] [info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 107 110 3 938.1 1.1 1.0X [info] java 69 72 3 1447.1 0.7 1.5X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 98 2 
1046.5 1.0 1.0X [info] java 43 45 2 2317.1 0.4 2.2X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 155 168 8 644.2 1.6 1.0X [info] java 158 169 8 632.8 1.6 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 85 90 4 1178.1 0.8 1.0X [info] java 86 90 4 1167.7 0.9 1.0X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 1182.1 0.8 1.0X [info] java 0 0 0 1432.1 0.7 1.2X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 898.7 1.1 1.0X [info] java 1 1 0 891.5 1.1 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 495.4 2.0 1.0X [info] java 1 1 0 495.7 2.0 1.0X [info] [info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2271.6 0.4 1.0X [info] java 0 0 0 3648.1 0.3 1.6X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] 
f2j 1 1 0 1229.3 0.8 1.0X [info] java 0 0 0 2711.3 0.4 2.2X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2677.5 0.4 1.0X [info] java 0 0 0 3288.2 0.3 1.2X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1233.0 0.8 1.0X [info] java 0 0 0 2766.3 0.4 2.2X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 520 536 16 1923.6 0.5 1.0X [info] java 214 221 7 4669.5 0.2 2.4X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 593 612 17 1686.5 0.6 1.0X [info] java 215 219 3 4643.3 0.2 2.8X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 853 870 16 1172.8 0.9 1.0X [info] java 215 218 3 4659.7 0.2 4.0X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1350 1370 23 740.8 1.3 1.0X [info] java 215 219 4 4656.6 0.2 6.3X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] 
------------------------------------------------------------------------------------------------------------------------ [info] f2j 460 468 6 2173.2 0.5 1.0X [info] java 210 213 2 4752.7 0.2 2.2X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 535 544 8 1869.3 0.5 1.0X [info] java 210 215 5 4761.8 0.2 2.5X [info] [info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 843 853 11 1186.8 0.8 1.0X [info] java 209 214 4 4793.4 0.2 4.0X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 891 904 15 1122.0 0.9 1.0X [info] java 209 214 4 4777.2 0.2 4.3X ``` #### JDK16: ``` [info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.VectorizedBLAS [info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 194 199 7 515.7 1.9 1.0X [info] java 181 186 3 551.1 1.8 1.1X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 109 115 4 915.0 1.1 1.0X [info] java 88 92 3 1138.8 0.9 1.2X [info] [info] ddot: Best 
Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 108 110 2 922.6 1.1 1.0X [info] java 54 56 2 1839.2 0.5 2.0X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 97 2 1046.1 1.0 1.0X [info] java 29 30 1 3393.4 0.3 3.2X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 156 165 5 643.0 1.6 1.0X [info] java 150 159 5 667.1 1.5 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 85 91 6 1171.0 0.9 1.0X [info] java 75 79 3 1340.6 0.7 1.1X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 917.0 1.1 1.0X [info] java 0 0 0 8147.2 0.1 8.9X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 859.3 1.2 1.0X [info] java 1 1 0 859.3 1.2 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 482.1 2.1 1.0X [info] java 1 1 0 482.6 2.1 1.0X [info] [info] 
dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2214.2 0.5 1.0X [info] java 0 0 0 7975.8 0.1 3.6X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1231.4 0.8 1.0X [info] java 0 0 0 8680.9 0.1 7.0X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2684.3 0.4 1.0X [info] java 0 0 0 18527.1 0.1 6.9X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1235.4 0.8 1.0X [info] java 0 0 0 17347.9 0.1 14.0X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 530 552 18 1887.5 0.5 1.0X [info] java 58 64 3 17143.9 0.1 9.1X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 598 620 17 1671.1 0.6 1.0X [info] java 58 64 3 17196.6 0.1 10.3X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 834 847 14 1199.4 0.8 1.0X [info] 
java 57 63 4 17486.9 0.1 14.6X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1338 1366 22 747.3 1.3 1.0X [info] java 58 63 3 17356.6 0.1 23.2X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 489 501 9 2045.5 0.5 1.0X [info] java 36 38 2 27721.9 0.0 13.6X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 478 488 9 2094.0 0.5 1.0X [info] java 36 38 2 27813.2 0.0 13.3X [info] [info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 825 837 10 1211.6 0.8 1.0X [info] java 35 38 2 28433.1 0.0 23.5X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 900 918 15 1111.6 0.9 1.0X [info] java 36 38 2 28073.0 0.0 25.3X ``` [2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas Closes #32253 from luhenry/master. 
Authored-by: Ludovic Henry Signed-off-by: Sean Owen --- dev/deps/spark-deps-hadoop-2.7-hive-2.3 | 3 + dev/deps/spark-deps-hadoop-3.2-hive-2.3 | 3 + docs/ml-linalg-guide.md | 2 +- graphx/pom.xml | 5 +- .../apache/spark/graphx/lib/SVDPlusPlus.scala | 31 +- licenses-binary/LICENSE-blas.txt | 25 + mllib-local/pom.xml | 33 +- .../spark/ml/linalg/VectorizedBLAS.java | 483 ---------------- .../org/apache/spark/ml/linalg/BLAS.scala | 29 +- .../spark/ml/linalg/BLASBenchmark.scala | 530 ++++++++++++------ mllib/pom.xml | 13 + .../apache/spark/mllib/feature/Word2Vec.scala | 24 +- .../apache/spark/mllib/linalg/ARPACK.scala | 46 ++ .../org/apache/spark/mllib/linalg/BLAS.scala | 40 +- .../mllib/linalg/CholeskyDecomposition.scala | 5 +- .../linalg/EigenValueDecomposition.scala | 15 +- .../apache/spark/mllib/linalg/LAPACK.scala | 46 ++ .../apache/spark/mllib/linalg/Matrices.scala | 3 +- .../spark/mllib/optimization/NNLS.scala | 32 +- .../MatrixFactorizationModel.scala | 11 +- .../spark/mllib/stat/KernelDensity.scala | 7 +- .../mllib/tree/model/treeEnsembleModels.scala | 4 +- .../spark/mllib/util/SVMDataGenerator.scala | 6 +- .../apache/spark/mllib/linalg/BLASSuite.scala | 2 +- pom.xml | 16 + project/SparkBuild.scala | 2 +- python/pyspark/ml/recommendation.py | 4 +- 27 files changed, 627 insertions(+), 793 deletions(-) create mode 100644 licenses-binary/LICENSE-blas.txt delete mode 100644 mllib-local/src/jvm-vectorized/java/org/apache/spark/ml/linalg/VectorizedBLAS.java create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index 7534268a0a67e..1b1bdf56f778c 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -15,6 +15,7 @@ apacheds-i18n/2.0.0-M15//apacheds-i18n-2.0.0-M15.jar 
apacheds-kerberos-codec/2.0.0-M15//apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api/1.0.0-M20//api-asn1-api-1.0.0-M20.jar api-util/1.0.0-M20//api-util-1.0.0-M20.jar +arpack/1.3.2//arpack-1.3.2.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar arrow-format/2.0.0//arrow-format-2.0.0.jar arrow-memory-core/2.0.0//arrow-memory-core-2.0.0.jar @@ -25,6 +26,7 @@ automaton/1.11-8//automaton-1.11-8.jar avro-ipc/1.10.2//avro-ipc-1.10.2.jar avro-mapred/1.10.2//avro-mapred-1.10.2.jar avro/1.10.2//avro-1.10.2.jar +blas/1.3.2//blas-1.3.2.jar bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar breeze-macros_2.12/1.0//breeze-macros_2.12-1.0.jar breeze_2.12/1.0//breeze_2.12-1.0.jar @@ -173,6 +175,7 @@ kubernetes-model-policy/5.3.0//kubernetes-model-policy-5.3.0.jar kubernetes-model-rbac/5.3.0//kubernetes-model-rbac-5.3.0.jar kubernetes-model-scheduling/5.3.0//kubernetes-model-scheduling-5.3.0.jar kubernetes-model-storageclass/5.3.0//kubernetes-model-storageclass-5.3.0.jar +lapack/1.3.2//lapack-1.3.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.12.0//libthrift-0.12.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index a86b4832b3986..d5d0890c32b86 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -10,6 +10,7 @@ annotations/17.0.0//annotations-17.0.0.jar antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar antlr4-runtime/4.8-1//antlr4-runtime-4.8-1.jar aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar +arpack/1.3.2//arpack-1.3.2.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar arrow-format/2.0.0//arrow-format-2.0.0.jar arrow-memory-core/2.0.0//arrow-memory-core-2.0.0.jar @@ -20,6 +21,7 @@ automaton/1.11-8//automaton-1.11-8.jar avro-ipc/1.10.2//avro-ipc-1.10.2.jar avro-mapred/1.10.2//avro-mapred-1.10.2.jar avro/1.10.2//avro-1.10.2.jar +blas/1.3.2//blas-1.3.2.jar bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar 
breeze-macros_2.12/1.0//breeze-macros_2.12-1.0.jar breeze_2.12/1.0//breeze_2.12-1.0.jar @@ -144,6 +146,7 @@ kubernetes-model-policy/5.3.0//kubernetes-model-policy-5.3.0.jar kubernetes-model-rbac/5.3.0//kubernetes-model-rbac-5.3.0.jar kubernetes-model-scheduling/5.3.0//kubernetes-model-scheduling-5.3.0.jar kubernetes-model-storageclass/5.3.0//kubernetes-model-storageclass-5.3.0.jar +lapack/1.3.2//lapack-1.3.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.12.0//libthrift-0.12.0.jar diff --git a/docs/ml-linalg-guide.md b/docs/ml-linalg-guide.md index 739091363473f..719554af5a2d2 100644 --- a/docs/ml-linalg-guide.md +++ b/docs/ml-linalg-guide.md @@ -82,7 +82,7 @@ WARN BLAS: Failed to load implementation from:com.github.fommil.netlib.NativeSys WARN BLAS: Failed to load implementation from:com.github.fommil.netlib.NativeRefBLAS ``` -If native libraries are not properly configured in the system, the Java implementation (f2jBLAS) will be used as fallback option. +If native libraries are not properly configured in the system, the Java implementation (javaBLAS) will be used as fallback option. 
## Spark Configuration diff --git a/graphx/pom.xml b/graphx/pom.xml index 3ed68c0652711..c4fa38a1dc9e5 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -60,9 +60,8 @@ guava - com.github.fommil.netlib - core - ${netlib.java.version} + dev.ludovic.netlib + blas net.sourceforge.f2j diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index db786a194e19c..d7099c5c953c1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -19,9 +19,8 @@ package org.apache.spark.graphx.lib import scala.util.Random -import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.apache.spark.graphx._ +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.rdd._ /** Implementation of SVD++ algorithm. */ @@ -102,22 +101,22 @@ object SVDPlusPlus { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) val rank = p.length - var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1) + var pred = u + usr._3 + itm._3 + BLAS.nativeBLAS.ddot(rank, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = ctx.attr - pred // updateP = (err * q - conf.gamma7 * p) * conf.gamma2 val updateP = q.clone() - blas.dscal(rank, err * conf.gamma2, updateP, 1) - blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1) + BLAS.nativeBLAS.dscal(rank, err * conf.gamma2, updateP, 1) + BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1) // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2 val updateQ = usr._2.clone() - blas.dscal(rank, err * conf.gamma2, updateQ, 1) - blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1) + BLAS.nativeBLAS.dscal(rank, err * conf.gamma2, updateQ, 1) + BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1) // updateY = (err * usr._4 * q - 
conf.gamma7 * itm._2) * conf.gamma2 val updateY = q.clone() - blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1) - blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1) + BLAS.nativeBLAS.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1) + BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1) ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } @@ -129,7 +128,7 @@ object SVDPlusPlus { ctx => ctx.sendToSrc(ctx.dstAttr._2), (g1, g2) => { val out = g1.clone() - blas.daxpy(out.length, 1.0, g2, 1, out, 1) + BLAS.nativeBLAS.daxpy(out.length, 1.0, g2, 1, out, 1) out }) val gJoinT1 = g.outerJoinVertices(t1) { @@ -137,7 +136,7 @@ object SVDPlusPlus { msg: Option[Array[Double]]) => if (msg.isDefined) { val out = vd._1.clone() - blas.daxpy(out.length, vd._4, msg.get, 1, out, 1) + BLAS.nativeBLAS.daxpy(out.length, vd._4, msg.get, 1, out, 1) (vd._1, out, vd._3, vd._4) } else { vd @@ -154,9 +153,9 @@ object SVDPlusPlus { (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) => { val out1 = g1._1.clone() - blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1) + BLAS.nativeBLAS.daxpy(out1.length, 1.0, g2._1, 1, out1, 1) val out2 = g2._2.clone() - blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1) + BLAS.nativeBLAS.daxpy(out2.length, 1.0, g2._2, 1, out2, 1) (out1, out2, g1._3 + g2._3) }) val gJoinT2 = g.outerJoinVertices(t2) { @@ -164,9 +163,9 @@ object SVDPlusPlus { vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Array[Double], Array[Double], Double)]) => { val out1 = vd._1.clone() - blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1) + BLAS.nativeBLAS.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1) val out2 = vd._2.clone() - blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1) + BLAS.nativeBLAS.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1) (out1, out2, vd._3 + msg.get._3, vd._4) } 
}.cache() @@ -180,7 +179,7 @@ object SVDPlusPlus { (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]): Unit = { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1) + var pred = u + usr._3 + itm._3 + BLAS.nativeBLAS.ddot(q.length, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = (ctx.attr - pred) * (ctx.attr - pred) diff --git a/licenses-binary/LICENSE-blas.txt b/licenses-binary/LICENSE-blas.txt new file mode 100644 index 0000000000000..2b8bec28b0d3b --- /dev/null +++ b/licenses-binary/LICENSE-blas.txt @@ -0,0 +1,25 @@ +MIT License +----------- + +Copyright 2020, 2021, Ludovic Henry + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Please contact git@ludovic.dev or visit ludovic.dev if you need additional +information or have any questions. 
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 16cd55c8a45f0..a977ae3c10280 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -75,6 +75,11 @@ test-jar test + + + dev.ludovic.netlib + blas + @@ -88,34 +93,6 @@ - - jvm-vectorized - - src/jvm-vectorized/java - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-vectorized-sources - generate-sources - - add-source - - - - ${extra.source.dir} - - - - - - - - target/scala-${scala.binary.version}/classes diff --git a/mllib-local/src/jvm-vectorized/java/org/apache/spark/ml/linalg/VectorizedBLAS.java b/mllib-local/src/jvm-vectorized/java/org/apache/spark/ml/linalg/VectorizedBLAS.java deleted file mode 100644 index 7db1bb9111a00..0000000000000 --- a/mllib-local/src/jvm-vectorized/java/org/apache/spark/ml/linalg/VectorizedBLAS.java +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ml.linalg; - -import com.github.fommil.netlib.F2jBLAS; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.FloatVector; -import jdk.incubator.vector.VectorOperators; -import jdk.incubator.vector.VectorSpecies; - -public class VectorizedBLAS extends F2jBLAS { - - private static final VectorSpecies FMAX = FloatVector.SPECIES_MAX; - private static final VectorSpecies DMAX = DoubleVector.SPECIES_MAX; - - // y += alpha * x - @Override - public void daxpy(int n, double alpha, double[] x, int incx, double[] y, int incy) { - if (n >= 0 - && x != null && x.length >= n && incx == 1 - && y != null && y.length >= n && incy == 1) { - if (alpha != 0.) { - DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha); - int i = 0; - for (; i < DMAX.loopBound(n); i += DMAX.length()) { - DoubleVector vx = DoubleVector.fromArray(DMAX, x, i); - DoubleVector vy = DoubleVector.fromArray(DMAX, y, i); - vx.fma(valpha, vy).intoArray(y, i); - } - for (; i < n; i += 1) { - y[i] += alpha * x[i]; - } - } - } else { - super.daxpy(n, alpha, x, incx, y, incy); - } - } - - // sum(x * y) - @Override - public float sdot(int n, float[] x, int incx, float[] y, int incy) { - if (n >= 0 - && x != null && x.length >= n && incx == 1 - && y != null && y.length >= n && incy == 1) { - float sum = 0.0f; - int i = 0; - FloatVector vsum = FloatVector.zero(FMAX); - for (; i < FMAX.loopBound(n); i += FMAX.length()) { - FloatVector vx = FloatVector.fromArray(FMAX, x, i); - FloatVector vy = FloatVector.fromArray(FMAX, y, i); - vsum = vx.fma(vy, vsum); - } - sum += vsum.reduceLanes(VectorOperators.ADD); - for (; i < n; i += 1) { - sum += x[i] * y[i]; - } - return sum; - } else { - return super.sdot(n, x, incx, y, incy); - } - } - - // sum(x * y) - @Override - public double ddot(int n, double[] x, int incx, double[] y, int incy) { - if (n >= 0 - && x != null && x.length >= n && incx == 1 - && y != null && y.length >= n && incy == 1) { - double sum = 0.; - int i = 0; - 
DoubleVector vsum = DoubleVector.zero(DMAX); - for (; i < DMAX.loopBound(n); i += DMAX.length()) { - DoubleVector vx = DoubleVector.fromArray(DMAX, x, i); - DoubleVector vy = DoubleVector.fromArray(DMAX, y, i); - vsum = vx.fma(vy, vsum); - } - sum += vsum.reduceLanes(VectorOperators.ADD); - for (; i < n; i += 1) { - sum += x[i] * y[i]; - } - return sum; - } else { - return super.ddot(n, x, incx, y, incy); - } - } - - @Override - public void dscal(int n, double alpha, double[] x, int incx) { - dscal(n, alpha, x, 0, incx); - } - - // x = alpha * x - @Override - public void dscal(int n, double alpha, double[] x, int offsetx, int incx) { - if (n >= 0 && x != null && x.length >= offsetx + n && incx == 1) { - if (alpha != 1.) { - DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha); - int i = 0; - for (; i < DMAX.loopBound(n); i += DMAX.length()) { - DoubleVector vx = DoubleVector.fromArray(DMAX, x, offsetx + i); - vx.mul(valpha).intoArray(x, offsetx + i); - } - for (; i < n; i += 1) { - x[offsetx + i] *= alpha; - } - } - } else { - super.dscal(n, alpha, x, offsetx, incx); - } - } - - @Override - public void sscal(int n, float alpha, float[] x, int incx) { - sscal(n, alpha, x, 0, incx); - } - - // x = alpha * x - @Override - public void sscal(int n, float alpha, float[] x, int offsetx, int incx) { - if (n >= 0 && x != null && x.length >= offsetx + n && incx == 1) { - if (alpha != 1.) 
{ - FloatVector valpha = FloatVector.broadcast(FMAX, alpha); - int i = 0; - for (; i < FMAX.loopBound(n); i += FMAX.length()) { - FloatVector vx = FloatVector.fromArray(FMAX, x, offsetx + i); - vx.mul(valpha).intoArray(x, offsetx + i); - } - for (; i < n; i += 1) { - x[offsetx + i] *= alpha; - } - } - } else { - super.sscal(n, alpha, x, offsetx, incx); - } - } - - // y = alpha * a * x + beta * y - @Override - public void dspmv(String uplo, int n, double alpha, double[] a, - double[] x, int incx, double beta, double[] y, int incy) { - if ("U".equals(uplo) - && n >= 0 - && a != null && a.length >= n * (n + 1) / 2 - && x != null && x.length >= n && incx == 1 - && y != null && y.length >= n && incy == 1) { - // y = beta * y - dscal(n, beta, y, 1); - // y += alpha * A * x - if (alpha != 0.) { - DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha); - for (int row = 0; row < n; row += 1) { - int col = 0; - DoubleVector vyrowsum = DoubleVector.zero(DMAX); - DoubleVector valphaxrow = DoubleVector.broadcast(DMAX, alpha * x[row]); - for (; col < DMAX.loopBound(row); col += DMAX.length()) { - DoubleVector vx = DoubleVector.fromArray(DMAX, x, col); - DoubleVector vy = DoubleVector.fromArray(DMAX, y, col); - DoubleVector va = DoubleVector.fromArray(DMAX, a, col + row * (row + 1) / 2); - vyrowsum = valpha.mul(vx).fma(va, vyrowsum); - valphaxrow.fma(va, vy).intoArray(y, col); - } - y[row] += vyrowsum.reduceLanes(VectorOperators.ADD); - for (; col < row; col += 1) { - y[row] += alpha * x[col] * a[col + row * (row + 1) / 2]; - y[col] += alpha * x[row] * a[col + row * (row + 1) / 2]; - } - y[row] += alpha * x[col] * a[col + row * (row + 1) / 2]; - } - } - } else { - super.dspmv(uplo, n, alpha, a, x, incx, beta, y, incy); - } - } - - // a += alpha * x * x.t - @Override - public void dspr(String uplo, int n, double alpha, double[] x, int incx, double[] a) { - if ("U".equals(uplo) - && n >= 0 - && x != null && x.length >= n && incx == 1 - && a != null && a.length >= n * (n + 1) / 
2) { - if (alpha != 0.) { - for (int row = 0; row < n; row += 1) { - int col = 0; - DoubleVector valphaxrow = DoubleVector.broadcast(DMAX, alpha * x[row]); - for (; col < DMAX.loopBound(row + 1); col += DMAX.length()) { - DoubleVector vx = DoubleVector.fromArray(DMAX, x, col); - DoubleVector va = DoubleVector.fromArray(DMAX, a, col + row * (row + 1) / 2); - vx.fma(valphaxrow, va).intoArray(a, col + row * (row + 1) / 2); - } - for (; col < row + 1; col += 1) { - a[col + row * (row + 1) / 2] += alpha * x[row] * x[col]; - } - } - } - } else { - super.dspr(uplo, n, alpha, x, incx, a); - } - } - - // a += alpha * x * x.t - @Override - public void dsyr(String uplo, int n, double alpha, double[] x, int incx, double[] a, int lda) { - if ("U".equals(uplo) - && n >= 0 - && x != null && x.length >= n && incx == 1 - && a != null && a.length >= n * n && lda == n) { - if (alpha != 0.) { - for (int row = 0; row < n; row += 1) { - int col = 0; - DoubleVector valphaxrow = DoubleVector.broadcast(DMAX, alpha * x[row]); - for (; col < DMAX.loopBound(row + 1); col += DMAX.length()) { - DoubleVector vx = DoubleVector.fromArray(DMAX, x, col); - DoubleVector va = DoubleVector.fromArray(DMAX, a, col + row * n); - vx.fma(valphaxrow, va).intoArray(a, col + row * n); - } - for (; col < row + 1; col += 1) { - a[col + row * n] += alpha * x[row] * x[col]; - } - } - } - } else { - super.dsyr(uplo, n, alpha, x, incx, a, lda); - } - } - - @Override - public void dgemv(String trans, int m, int n, - double alpha, double[] a, int lda, double[] x, int incx, - double beta, double[] y, int incy) { - dgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); - } - - // y = alpha * A * x + beta * y - @Override - public void dgemv(String trans, int m, int n, - double alpha, double[] a, int offseta, int lda, double[] x, int offsetx, int incx, - double beta, double[] y, int offsety, int incy) { - if ("N".equals(trans) - && m >= 0 && n >= 0 - && a != null && a.length >= offseta + m * n && lda == m - 
&& x != null && x.length >= offsetx + n && incx == 1 - && y != null && y.length >= offsety + m && incy == 1) { - // y = beta * y - dscal(m, beta, y, offsety, 1); - // y += alpha * A * x - if (alpha != 0.) { - DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha); - for (int col = 0; col < n; col += 1) { - int row = 0; - for (; row < DMAX.loopBound(m); row += DMAX.length()) { - DoubleVector va = DoubleVector.fromArray(DMAX, a, offseta + row + col * m); - DoubleVector vy = DoubleVector.fromArray(DMAX, y, offsety + row); - valpha.mul(x[offsetx + col]).fma(va, vy) - .intoArray(y, offsety + row); - } - for (; row < m; row += 1) { - y[offsety + row] += alpha * x[offsetx + col] * a[offseta + row + col * m]; - } - } - } - } else if ("T".equals(trans) - && m >= 0 && n >= 0 - && a != null && a.length >= offseta + m * n && lda == m - && x != null && x.length >= offsetx + m && incx == 1 - && y != null && y.length >= offsety + n && incy == 1) { - if (alpha != 0. || beta != 1.) { - for (int col = 0; col < n; col += 1) { - double sum = 0.; - int row = 0; - DoubleVector vsum = DoubleVector.zero(DMAX); - for (; row < DMAX.loopBound(m); row += DMAX.length()) { - DoubleVector va = DoubleVector.fromArray(DMAX, a, offseta + row + col * m); - DoubleVector vx = DoubleVector.fromArray(DMAX, x, offsetx + row); - vsum = va.fma(vx, vsum); - } - sum += vsum.reduceLanes(VectorOperators.ADD); - for (; row < m; row += 1) { - sum += x[offsetx + row] * a[offseta + row + col * m]; - } - y[offsety + col] = alpha * sum + beta * y[offsety + col]; - } - } - } else { - super.dgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); - } - } - - @Override - public void sgemv(String trans, int m, int n, - float alpha, float[] a, int lda, float[] x, int incx, - float beta, float[] y, int incy) { - sgemv(trans, m, n, alpha, a, 0, lda, x, 0, incx, beta, y, 0, incy); - } - - // y = alpha * A * x + beta * y - @Override - public void sgemv(String trans, int m, int n, - float 
alpha, float[] a, int offseta, int lda, float[] x, int offsetx, int incx, - float beta, float[] y, int offsety, int incy) { - if ("N".equals(trans) - && m >= 0 && n >= 0 - && a != null && a.length >= offseta + m * n && lda == m - && x != null && x.length >= offsetx + n && incx == 1 - && y != null && y.length >= offsety + m && incy == 1) { - // y = beta * y - sscal(m, beta, y, offsety, 1); - // y += alpha * A * x - if (alpha != 0.f) { - FloatVector valpha = FloatVector.broadcast(FMAX, alpha); - for (int col = 0; col < n; col += 1) { - int row = 0; - for (; row < FMAX.loopBound(m); row += FMAX.length()) { - FloatVector va = FloatVector.fromArray(FMAX, a, offseta + row + col * m); - FloatVector vy = FloatVector.fromArray(FMAX, y, offsety + row); - valpha.mul(x[offsetx + col]).fma(va, vy) - .intoArray(y, offsety + row); - } - for (; row < m; row += 1) { - y[offsety + row] += alpha * x[offsetx + col] * a[offseta + row + col * m]; - } - } - } - } else if ("T".equals(trans) - && m >= 0 && n >= 0 - && a != null && a.length >= offseta + m * n && lda == m - && x != null && x.length >= offsetx + m && incx == 1 - && y != null && y.length >= offsety + n && incy == 1) { - if (alpha != 0. || beta != 1.) 
{ - for (int col = 0; col < n; col += 1) { - float sum = 0.f; - int row = 0; - FloatVector vsum = FloatVector.zero(FMAX); - for (; row < FMAX.loopBound(m); row += FMAX.length()) { - FloatVector va = FloatVector.fromArray(FMAX, a, offseta + row + col * m); - FloatVector vx = FloatVector.fromArray(FMAX, x, offsetx + row); - vsum = va.fma(vx, vsum); - } - sum += vsum.reduceLanes(VectorOperators.ADD); - for (; row < m; row += 1) { - sum += x[offsetx + row] * a[offseta + row + col * m]; - } - y[offsety + col] = alpha * sum + beta * y[offsety + col]; - } - } - } else { - super.sgemv(trans, m, n, alpha, a, offseta, lda, x, offsetx, incx, beta, y, offsety, incy); - } - } - - @Override - public void dgemm(String transa, String transb, int m, int n, int k, - double alpha, double[] a, int lda, double[] b, int ldb, - double beta, double[] c, int ldc) { - dgemm(transa, transb, m, n, k, alpha, a, 0, lda, b, 0, ldb, beta, c, 0, ldc); - } - - // c = alpha * a * b + beta * c - @Override - public void dgemm(String transa, String transb, int m, int n, int k, - double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, - double beta, double[] c, int offsetc, int ldc) { - if ("N".equals(transa) && "N".equals(transb) - && m >= 0 && n >= 0 && k >= 0 - && a != null && a.length >= offseta + m * k && lda == m - && b != null && b.length >= offsetb + k * n && ldb == k - && c != null && c.length >= offsetc + m * n && ldc == m) { - // C = beta * C - dscal(m * n, beta, c, offsetc, 1); - // C += alpha * A * B - if (alpha != 0.) 
{ - DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha); - for (int col = 0; col < n; col += 1) { - for (int i = 0; i < k; i += 1) { - int row = 0; - for (; row < DMAX.loopBound(m); row += DMAX.length()) { - DoubleVector va = DoubleVector.fromArray(DMAX, a, offseta + i * m + row); - DoubleVector vc = DoubleVector.fromArray(DMAX, c, offsetc + col * m + row); - valpha.mul(b[offsetb + col * k + i]).fma(va, vc) - .intoArray(c, offsetc + col * m + row); - } - for (; row < m; row += 1) { - c[offsetc + col * m + row] += alpha * a[offseta + i * m + row] * b[offsetb + col * k + i]; - } - } - } - } - } else if ("N".equals(transa) && "T".equals(transb) - && m >= 0 && n >= 0 && k >= 0 - && a != null && a.length >= offseta + m * k && lda == m - && b != null && b.length >= offsetb + k * n && ldb == n - && c != null && c.length >= offsetc + m * n && ldc == m) { - // C = beta * C - dscal(m * n, beta, c, offsetc, 1); - // C += alpha * A * B - if (alpha != 0.) { - DoubleVector valpha = DoubleVector.broadcast(DMAX, alpha); - for (int i = 0; i < k; i += 1) { - for (int col = 0; col < n; col += 1) { - int row = 0; - for (; row < DMAX.loopBound(m); row += DMAX.length()) { - DoubleVector va = DoubleVector.fromArray(DMAX, a, offseta + i * m + row); - DoubleVector vc = DoubleVector.fromArray(DMAX, c, offsetc + col * m + row); - valpha.mul(b[offsetb + col + i * n]).fma(va, vc) - .intoArray(c, offsetc + col * m + row); - } - for (; row < m; row += 1) { - c[offsetc + col * m + row] += alpha * a[offseta + i * m + row] * b[offsetb + col + i * n]; - } - } - } - } - } else if ("T".equals(transa) && "N".equals(transb) - && m >= 0 && n >= 0 && k >= 0 - && a != null && a.length >= offseta + m * k && lda == k - && b != null && b.length >= offsetb + k * n && ldb == k - && c != null && c.length >= offsetc + m * n && ldc == m) { - if (alpha != 0. || beta != 1.) 
{ - for (int col = 0; col < n; col += 1) { - for (int row = 0; row < m; row += 1) { - double sum = 0.; - int i = 0; - DoubleVector vsum = DoubleVector.zero(DMAX); - for (; i < DMAX.loopBound(k); i += DMAX.length()) { - DoubleVector va = DoubleVector.fromArray(DMAX, a, offseta + i + row * k); - DoubleVector vb = DoubleVector.fromArray(DMAX, b, offsetb + col * k + i); - vsum = va.fma(vb, vsum); - } - sum += vsum.reduceLanes(VectorOperators.ADD); - for (; i < k; i += 1) { - sum += a[offseta + i + row * k] * b[offsetb + col * k + i]; - } - if (beta != 0.) { - c[offsetc + col * m + row] = alpha * sum + beta * c[offsetc + col * m + row]; - } else { - c[offsetc + col * m + row] = alpha * sum; - } - } - } - } - } else if ("T".equals(transa) && "T".equals(transb) - && m >= 0 && n >= 0 && k >= 0 - && a != null && a.length >= offseta + m * k && lda == k - && b != null && b.length >= offsetb + k * n && ldb == n - && c != null && c.length >= offsetc + m * n && ldc == m) { - if (alpha != 0. || beta != 1.) { - // FIXME: do block by block - for (int col = 0; col < n; col += 1) { - for (int row = 0; row < m; row += 1) { - double sum = 0.; - for (int i = 0; i < k; i += 1) { - sum += a[offseta + i + row * k] * b[offsetb + col + i * n]; - } - if (beta != 0.) 
{ - c[offsetc + col * m + row] = alpha * sum + beta * c[offsetc + col * m + row]; - } else { - c[offsetc + col * m + row] = alpha * sum; - } - } - } - } - } else { - super.dgemm(transa, transb, m, n, k, - alpha, a, offseta, lda, b, offsetb, ldb, - beta, c, offsetc, ldc); - } - } -} diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 518c71129a970..5a6bee3e74ead 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.linalg -import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} +import dev.ludovic.netlib.{BLAS => NetlibBLAS, + JavaBLAS => NetlibJavaBLAS, + NativeBLAS => NetlibNativeBLAS} /** * BLAS routines for MLlib's vectors and matrices. @@ -29,38 +31,23 @@ private[spark] object BLAS extends Serializable { private val nativeL1Threshold: Int = 256 // For level-1 function dspmv, use javaBLAS for better performance. - private[ml] def javaBLAS: NetlibBLAS = { + private[spark] def javaBLAS: NetlibBLAS = { if (_javaBLAS == null) { - _javaBLAS = - try { - // scalastyle:off classforname - Class.forName("org.apache.spark.ml.linalg.VectorizedBLAS", true, - Option(Thread.currentThread().getContextClassLoader) - .getOrElse(getClass.getClassLoader)) - .newInstance() - .asInstanceOf[NetlibBLAS] - // scalastyle:on classforname - } catch { - case _: Throwable => new F2jBLAS - } + _javaBLAS = NetlibJavaBLAS.getInstance } _javaBLAS } // For level-3 routines, we use the native BLAS. 
- private[ml] def nativeBLAS: NetlibBLAS = { + private[spark] def nativeBLAS: NetlibBLAS = { if (_nativeBLAS == null) { _nativeBLAS = - if (NetlibBLAS.getInstance.isInstanceOf[F2jBLAS]) { - javaBLAS - } else { - NetlibBLAS.getInstance - } + try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS } } _nativeBLAS } - private[ml] def getBLAS(vectorSize: Int): NetlibBLAS = { + private[spark] def getBLAS(vectorSize: Int): NetlibBLAS = { if (vectorSize < nativeL1Threshold) { javaBLAS } else { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala index 1dcfcf9ebb034..144f59ac172fe 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.linalg -import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} +import dev.ludovic.netlib.blas.NetlibF2jBLAS +import scala.concurrent.duration._ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} @@ -38,48 +39,66 @@ object BLASBenchmark extends BenchmarkBase { val iters = 1e2.toInt val rnd = new scala.util.Random(0) - val f2jBLAS = new F2jBLAS - val nativeBLAS = NetlibBLAS.getInstance - val vectorBLAS = - try { - // scalastyle:off classforname - Class.forName("org.apache.spark.ml.linalg.VectorizedBLAS", true, - Option(Thread.currentThread().getContextClassLoader) - .getOrElse(getClass.getClassLoader)) - .newInstance() - .asInstanceOf[NetlibBLAS] - // scalastyle:on classforname - } catch { - case _: Throwable => new F2jBLAS - } + val f2jBLAS = NetlibF2jBLAS.getInstance + val javaBLAS = BLAS.javaBLAS + val nativeBLAS = BLAS.nativeBLAS // scalastyle:off println - println("nativeBLAS = " + nativeBLAS.getClass.getName) println("f2jBLAS = " + f2jBLAS.getClass.getName) - println("vectorBLAS = " + vectorBLAS.getClass.getName) + println("javaBLAS = " + 
javaBLAS.getClass.getName) + println("nativeBLAS = " + nativeBLAS.getClass.getName) // scalastyle:on println runBenchmark("daxpy") { - val n = 1e7.toInt + val n = 1e8.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("daxpy", n, iters, output = output) + val benchmark = new Benchmark("daxpy", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.daxpy(n, alpha, x, 1, y, 1) + f2jBLAS.daxpy(n, alpha, x, 1, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { + benchmark.addCase("java") { _ => + javaBLAS.daxpy(n, alpha, x, 1, y.clone, 1) + } + + if (nativeBLAS != javaBLAS) { benchmark.addCase("native") { _ => - nativeBLAS.daxpy(n, alpha, x, 1, y, 1) + nativeBLAS.daxpy(n, alpha, x, 1, y.clone, 1) } } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.daxpy(n, alpha, x, 1, y, 1) + benchmark.run() + } + + runBenchmark("saxpy") { + val n = 1e8.toInt + val alpha = rnd.nextFloat + val x = Array.fill(n) { rnd.nextFloat } + val y = Array.fill(n) { rnd.nextFloat } + + val benchmark = new Benchmark("saxpy", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.saxpy(n, alpha, x, 1, y.clone, 1) + } + + benchmark.addCase("java") { _ => + javaBLAS.saxpy(n, alpha, x, 1, y.clone, 1) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.saxpy(n, alpha, x, 1, y.clone, 1) } } @@ -87,25 +106,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("ddot") { - val n = 1e7.toInt + val n = 1e8.toInt val x = Array.fill(n) { rnd.nextDouble } val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("ddot", n, iters, output = output) + val benchmark = new Benchmark("ddot", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + 
output = output) benchmark.addCase("f2j") { _ => f2jBLAS.ddot(n, x, 1, y, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.ddot(n, x, 1, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.ddot(n, x, 1, y, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.ddot(n, x, 1, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.ddot(n, x, 1, y, 1) } } @@ -113,25 +133,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sdot") { - val n = 1e7.toInt + val n = 1e8.toInt val x = Array.fill(n) { rnd.nextFloat } val y = Array.fill(n) { rnd.nextFloat } - val benchmark = new Benchmark("sdot", n, iters, output = output) + val benchmark = new Benchmark("sdot", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => f2jBLAS.sdot(n, x, 1, y, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.sdot(n, x, 1, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sdot(n, x, 1, y, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.sdot(n, x, 1, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sdot(n, x, 1, y, 1) } } @@ -139,25 +160,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dscal") { - val n = 1e7.toInt + val n = 1e8.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("dscal", n, iters, output = output) + val benchmark = new Benchmark("dscal", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dscal(n, alpha, x, 1) + f2jBLAS.dscal(n, alpha, x.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - 
nativeBLAS.dscal(n, alpha, x, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dscal(n, alpha, x.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dscal(n, alpha, x, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dscal(n, alpha, x.clone, 1) } } @@ -165,25 +187,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sscal") { - val n = 1e7.toInt + val n = 1e8.toInt val alpha = rnd.nextFloat val x = Array.fill(n) { rnd.nextFloat } - val benchmark = new Benchmark("sscal", n, iters, output = output) + val benchmark = new Benchmark("sscal", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.sscal(n, alpha, x, 1) + f2jBLAS.sscal(n, alpha, x.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.sscal(n, alpha, x, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sscal(n, alpha, x.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.sscal(n, alpha, x, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sscal(n, alpha, x.clone, 1) } } @@ -191,28 +214,29 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dspmv[U]") { - val n = 1e4.toInt + val n = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(n * (n + 1) / 2) { rnd.nextDouble } val x = Array.fill(n) { rnd.nextDouble } val beta = rnd.nextDouble val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("dspmv[U]", n, iters, output = output) + val benchmark = new Benchmark("dspmv[U]", n * (n + 1) / 2, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dspmv("U", n, alpha, a, x, 1, beta, y, 1) + f2jBLAS.dspmv("U", n, alpha, a, x, 1, beta, y.clone, 1) } - if 
(!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dspmv("U", n, alpha, a, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dspmv("U", n, alpha, a, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dspmv("U", n, alpha, a, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dspmv("U", n, alpha, a, x, 1, beta, y.clone, 1) } } @@ -220,26 +244,27 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dspr[U]") { - val n = 1e4.toInt + val n = 1e3.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } val a = Array.fill(n * (n + 1) / 2) { rnd.nextDouble } - val benchmark = new Benchmark("dspr[U]", n, iters, output = output) + val benchmark = new Benchmark("dspr[U]", n * (n + 1) / 2, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dspr("U", n, alpha, x, 1, a) + f2jBLAS.dspr("U", n, alpha, x, 1, a.clone) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dspr("U", n, alpha, x, 1, a) - } + benchmark.addCase("java") { _ => + javaBLAS.dspr("U", n, alpha, x, 1, a.clone) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dspr("U", n, alpha, x, 1, a) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dspr("U", n, alpha, x, 1, a.clone) } } @@ -247,26 +272,27 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dsyr[U]") { - val n = 1e4.toInt + val n = 1e3.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } val a = Array.fill(n * n) { rnd.nextDouble } - val benchmark = new Benchmark("dsyr[U]", n, iters, output = output) + val benchmark = new Benchmark("dsyr[U]", n * (n + 1) / 2, iters, + warmupTime = 30.seconds, + 
minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dsyr("U", n, alpha, x, 1, a, n) + f2jBLAS.dsyr("U", n, alpha, x, 1, a.clone, n) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dsyr("U", n, alpha, x, 1, a, n) - } + benchmark.addCase("java") { _ => + javaBLAS.dsyr("U", n, alpha, x, 1, a.clone, n) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dsyr("U", n, alpha, x, 1, a, n) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dsyr("U", n, alpha, x, 1, a.clone, n) } } @@ -274,7 +300,7 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dgemv[N]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * n) { rnd.nextDouble } @@ -283,21 +309,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextDouble val y = Array.fill(m) { rnd.nextDouble } - val benchmark = new Benchmark("dgemv[N]", n, iters, output = output) + val benchmark = new Benchmark("dgemv[N]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -305,7 +332,7 @@ object BLASBenchmark extends BenchmarkBase { } 
runBenchmark("dgemv[T]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * n) { rnd.nextDouble } @@ -314,21 +341,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextDouble val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("dgemv[T]", n, iters, output = output) + val benchmark = new Benchmark("dgemv[T]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -336,7 +364,7 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sgemv[N]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextFloat val a = Array.fill(m * n) { rnd.nextFloat } @@ -345,21 +373,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextFloat val y = Array.fill(m) { rnd.nextFloat } - val benchmark = new Benchmark("sgemv[N]", n, iters, output = output) + val benchmark = new Benchmark("sgemv[N]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - 
benchmark.addCase("native") { _ => - nativeBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -367,7 +396,7 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sgemv[T]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextFloat val a = Array.fill(m * n) { rnd.nextFloat } @@ -376,21 +405,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextFloat val y = Array.fill(n) { rnd.nextFloat } - val benchmark = new Benchmark("sgemv[T]", n, iters, output = output) + val benchmark = new Benchmark("sgemv[T]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -399,7 +429,7 @@ object BLASBenchmark extends BenchmarkBase { runBenchmark("dgemm[N,N]") { val m = 1e3.toInt - val n = 1e2.toInt + val n = 1e3.toInt val k = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * k) 
{ rnd.nextDouble } @@ -410,21 +440,22 @@ object BLASBenchmark extends BenchmarkBase { val c = Array.fill(m * n) { rnd.nextDouble } var ldc = m - val benchmark = new Benchmark("dgemm[N,N]", m*n, iters, output = output) + val benchmark = new Benchmark("dgemm[N,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + f2jBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } @@ -433,7 +464,7 @@ object BLASBenchmark extends BenchmarkBase { runBenchmark("dgemm[N,T]") { val m = 1e3.toInt - val n = 1e2.toInt + val n = 1e3.toInt val k = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * k) { rnd.nextDouble } @@ -444,21 +475,22 @@ object BLASBenchmark extends BenchmarkBase { val c = Array.fill(m * n) { rnd.nextDouble } var ldc = m - val benchmark = new Benchmark("dgemm[N,T]", m*n, iters, output = output) + val benchmark = new Benchmark("dgemm[N,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + f2jBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - 
nativeBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } @@ -467,7 +499,7 @@ object BLASBenchmark extends BenchmarkBase { runBenchmark("dgemm[T,N]") { val m = 1e3.toInt - val n = 1e2.toInt + val n = 1e3.toInt val k = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * k) { rnd.nextDouble } @@ -478,21 +510,197 @@ object BLASBenchmark extends BenchmarkBase { val c = Array.fill(m * n) { rnd.nextDouble } var ldc = m - val benchmark = new Benchmark("dgemm[T,N]", m*n, iters, output = output) + val benchmark = new Benchmark("dgemm[T,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + runBenchmark("dgemm[T,T]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextDouble + val a = Array.fill(m * k) { rnd.nextDouble } + val lda = k + val b = Array.fill(k * n) { rnd.nextDouble } + val ldb = n + val beta = rnd.nextDouble + val c = Array.fill(m * n) { rnd.nextDouble } + var ldc = m + + val benchmark = new Benchmark("dgemm[T,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + 
benchmark.addCase("f2j") { _ => + f2jBLAS.dgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.dgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + runBenchmark("sgemm[N,N]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = m + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = k + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[N,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.sgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + runBenchmark("sgemm[N,T]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = m + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = n + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[N,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.sgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, 
ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + runBenchmark("sgemm[T,N]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = k + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = k + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[T,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + f2jBLAS.sgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { + if (nativeBLAS != javaBLAS) { benchmark.addCase("native") { _ => - nativeBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + nativeBLAS.sgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + benchmark.run() + } + + runBenchmark("sgemm[T,T]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = k + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = n + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[T,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.sgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, 
c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } diff --git a/mllib/pom.xml b/mllib/pom.xml index f5b5a979e35b8..626ac85ce1ceb 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -142,6 +142,19 @@ test + + dev.ludovic.netlib + blas + + + dev.ludovic.netlib + lapack + + + dev.ludovic.netlib + arpack + + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a2c376d80e03a..d2cfedcc33e88 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -22,7 +22,6 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.google.common.collect.{Ordering => GuavaOrdering} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ @@ -34,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Kryo.KRYO_SERIALIZER_MAX_BUFFER_SIZE +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ @@ -401,18 +401,18 @@ class Word2Vec extends Serializable with Logging { val inner = bcVocab.value(word).point(d) val l2 = inner * vectorSize // Propagate hidden -> output - var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) + var f = BLAS.nativeBLAS.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 
2.0)).toInt f = expTable.value(ind) val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat - blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) - blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + BLAS.nativeBLAS.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) + BLAS.nativeBLAS.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) syn1Modify(inner) += 1 } d += 1 } - blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) + BLAS.nativeBLAS.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) syn0Modify(lastWord) += 1 } } @@ -448,10 +448,10 @@ class Word2Vec extends Serializable with Logging { (id, (vec, 1)) } }.reduceByKey { (vc1, vc2) => - blas.saxpy(vectorSize, 1.0f, vc2._1, 1, vc1._1, 1) + BLAS.nativeBLAS.saxpy(vectorSize, 1.0f, vc2._1, 1, vc1._1, 1) (vc1._1, vc1._2 + vc2._2) }.map { case (id, (vec, count)) => - blas.sscal(vectorSize, 1.0f / count, vec, 1) + BLAS.nativeBLAS.sscal(vectorSize, 1.0f / count, vec, 1) (id, vec) }.collect() var i = 0 @@ -511,7 +511,7 @@ class Word2VecModel private[spark] ( private lazy val wordVecInvNorms: Array[Float] = { val size = vectorSize Array.tabulate(numWords) { i => - val norm = blas.snrm2(size, wordVectors, i * size, 1) + val norm = BLAS.nativeBLAS.snrm2(size, wordVectors, i * size, 1) if (norm != 0) 1 / norm else 0.0F } } @@ -587,7 +587,7 @@ class Word2VecModel private[spark] ( val localVectorSize = vectorSize val floatVec = vector.map(_.toFloat) - val vecNorm = blas.snrm2(localVectorSize, floatVec, 1) + val vecNorm = BLAS.nativeBLAS.snrm2(localVectorSize, floatVec, 1) val localWordList = wordList val localNumWords = numWords @@ -597,11 +597,11 @@ class Word2VecModel private[spark] ( .take(num) .toArray } else { - // Normalize input vector before blas.sgemv to avoid Inf value - blas.sscal(localVectorSize, 1 / vecNorm, floatVec, 0, 1) + // Normalize input vector before BLAS.nativeBLAS.sgemv to avoid Inf value + BLAS.nativeBLAS.sscal(localVectorSize, 1 / vecNorm, floatVec, 0, 1) val cosineVec = 
Array.ofDim[Float](localNumWords) - blas.sgemv("T", localVectorSize, localNumWords, 1.0F, wordVectors, localVectorSize, + BLAS.nativeBLAS.sgemv("T", localVectorSize, localNumWords, 1.0F, wordVectors, localVectorSize, floatVec, 1, 0.0F, cosineVec, 1) val localWordVecInvNorms = wordVecInvNorms diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala new file mode 100644 index 0000000000000..fb0f6ddd470b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import dev.ludovic.netlib.{ARPACK => NetlibARPACK, + JavaARPACK => NetlibJavaARPACK, + NativeARPACK => NetlibNativeARPACK} + +/** + * ARPACK routines for MLlib's vectors and matrices. 
+ */ +private[spark] object ARPACK extends Serializable { + + @transient private var _javaARPACK: NetlibARPACK = _ + @transient private var _nativeARPACK: NetlibARPACK = _ + + private[spark] def javaARPACK: NetlibARPACK = { + if (_javaARPACK == null) { + _javaARPACK = NetlibJavaARPACK.getInstance + } + _javaARPACK + } + + private[spark] def nativeARPACK: NetlibARPACK = { + if (_nativeARPACK == null) { + _nativeARPACK = + try { NetlibNativeARPACK.getInstance } catch { case _: Throwable => javaARPACK } + } + _nativeARPACK + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index bd60364326e28..e38cfe4e18d40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -17,8 +17,9 @@ package org.apache.spark.mllib.linalg -import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} -import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} +import dev.ludovic.netlib.{BLAS => NetlibBLAS, + JavaBLAS => NetlibJavaBLAS, + NativeBLAS => NetlibNativeBLAS} import org.apache.spark.internal.Logging @@ -27,21 +28,30 @@ import org.apache.spark.internal.Logging */ private[spark] object BLAS extends Serializable with Logging { - @transient private var _f2jBLAS: NetlibBLAS = _ + @transient private var _javaBLAS: NetlibBLAS = _ @transient private var _nativeBLAS: NetlibBLAS = _ private val nativeL1Threshold: Int = 256 - // For level-1 function dspmv, use f2jBLAS for better performance. - private[mllib] def f2jBLAS: NetlibBLAS = { - if (_f2jBLAS == null) { - _f2jBLAS = new F2jBLAS + // For level-1 function dspmv, use javaBLAS for better performance. + private[spark] def javaBLAS: NetlibBLAS = { + if (_javaBLAS == null) { + _javaBLAS = NetlibJavaBLAS.getInstance } - _f2jBLAS + _javaBLAS } - private[mllib] def getBLAS(vectorSize: Int): NetlibBLAS = { + // For level-3 routines, we use the native BLAS. 
+ private[spark] def nativeBLAS: NetlibBLAS = { + if (_nativeBLAS == null) { + _nativeBLAS = + try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS } + } + _nativeBLAS + } + + private[spark] def getBLAS(vectorSize: Int): NetlibBLAS = { if (vectorSize < nativeL1Threshold) { - f2jBLAS + javaBLAS } else { nativeBLAS } @@ -237,14 +247,6 @@ private[spark] object BLAS extends Serializable with Logging { } } - // For level-3 routines, we use the native BLAS. - private[mllib] def nativeBLAS: NetlibBLAS = { - if (_nativeBLAS == null) { - _nativeBLAS = NativeBLAS - } - _nativeBLAS - } - /** * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR. * @@ -263,7 +265,7 @@ private[spark] object BLAS extends Serializable with Logging { val n = v.size v match { case DenseVector(values) => - NativeBLAS.dspr("U", n, alpha, values, 1, U) + nativeBLAS.dspr("U", n, alpha, values, 1, U) case SparseVector(size, indices, values) => val nnz = indices.length var colStartIdx = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index 68771f1afbe8c..f06ea9418f252 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.linalg -import com.github.fommil.netlib.LAPACK.{getInstance => lapack} import org.netlib.util.intW import org.apache.spark.ml.optim.SingularMatrixException @@ -37,7 +36,7 @@ private[spark] object CholeskyDecomposition { def solve(A: Array[Double], bx: Array[Double]): Array[Double] = { val k = bx.length val info = new intW(0) - lapack.dppsv("U", k, 1, A, bx, k, info) + LAPACK.nativeLAPACK.dppsv("U", k, 1, A, bx, k, info) checkReturnValue(info, "dppsv") bx } @@ -52,7 +51,7 @@ private[spark] object CholeskyDecomposition { */ def inverse(UAi: 
Array[Double], k: Int): Array[Double] = { val info = new intW(0) - lapack.dpptri("U", k, UAi, info) + LAPACK.nativeLAPACK.dpptri("U", k, UAi, info) checkReturnValue(info, "dpptri") UAi } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 4c71cd649621e..2cbf5d09dc56f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} -import com.github.fommil.netlib.ARPACK import org.netlib.util.{doubleW, intW} /** @@ -51,8 +50,6 @@ private[mllib] object EigenValueDecomposition { // TODO: remove this function and use eigs in breeze when switching breeze version require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n") - val arpack = ARPACK.getInstance() - // tolerance used in stopping criterion val tolW = new doubleW(tol) // number of desired eigenvalues, 0 < nev < n @@ -87,8 +84,8 @@ private[mllib] object EigenValueDecomposition { val ipntr = new Array[Int](11) // call ARPACK's reverse communication, first iteration with ido = 0 - arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, workd, - workl, workl.length, info) + ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, + v, n, iparam, ipntr, workd, workl, workl.length, info) val w = BDV(workd) @@ -105,8 +102,8 @@ private[mllib] object EigenValueDecomposition { val y = w.slice(outputOffset, outputOffset + n) y := mul(x) // call ARPACK's reverse communication - arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, - workd, workl, workl.length, info) + ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, + v, n, 
iparam, ipntr, workd, workl, workl.length, info) } if (info.`val` != 0) { @@ -127,8 +124,8 @@ private[mllib] object EigenValueDecomposition { val z = java.util.Arrays.copyOfRange(v, 0, nev.`val` * n) // call ARPACK's post-processing for eigenvectors - arpack.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid, ncv, v, n, - iparam, ipntr, workd, workl, workl.length, info) + ARPACK.nativeARPACK.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid, + ncv, v, n, iparam, ipntr, workd, workl, workl.length, info) // number of computed eigenvalues, might be smaller than k val computed = iparam(4) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala new file mode 100644 index 0000000000000..4d25aed2835cb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import dev.ludovic.netlib.{JavaLAPACK => NetlibJavaLAPACK, + LAPACK => NetlibLAPACK, + NativeLAPACK => NetlibNativeLAPACK} + +/** + * LAPACK routines for MLlib's vectors and matrices. 
+ */ +private[spark] object LAPACK extends Serializable { + + @transient private var _javaLAPACK: NetlibLAPACK = _ + @transient private var _nativeLAPACK: NetlibLAPACK = _ + + private[spark] def javaLAPACK: NetlibLAPACK = { + if (_javaLAPACK == null) { + _javaLAPACK = NetlibJavaLAPACK.getInstance + } + _javaLAPACK + } + + private[spark] def nativeLAPACK: NetlibLAPACK = { + if (_nativeLAPACK == null) { + _nativeLAPACK = + try { NetlibNativeLAPACK.getInstance } catch { case _: Throwable => javaLAPACK } + } + _nativeLAPACK + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 57edc965112ef..e4f64b4e34864 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, Has import scala.language.implicitConversions import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.annotation.Since import org.apache.spark.ml.{linalg => newlinalg} @@ -427,7 +426,7 @@ class DenseMatrix @Since("1.3.0") ( if (isTransposed) { Iterator.tabulate(numCols) { j => val col = new Array[Double](numRows) - blas.dcopy(numRows, values, j, numCols, col, 0, 1) + BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1) new DenseVector(col) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index 86632ae335957..e070d605b1647 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization import java.{util => ju} -import 
com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.ml.linalg.BLAS /** * Object used to solve nonnegative least squares problems using a modified @@ -75,10 +75,10 @@ private[spark] object NNLS { // find the optimal unconstrained step def steplen(dir: Array[Double], res: Array[Double]): Double = { - val top = blas.ddot(n, dir, 1, res, 1) - blas.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1) + val top = BLAS.nativeBLAS.ddot(n, dir, 1, res, 1) + BLAS.nativeBLAS.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1) // Push the denominator upward very slightly to avoid infinities and silliness - top / (blas.ddot(n, scratch, 1, dir, 1) + 1e-20) + top / (BLAS.nativeBLAS.ddot(n, scratch, 1, dir, 1) + 1e-20) } // stopping condition @@ -103,9 +103,9 @@ private[spark] object NNLS { var i = 0 while (iterno < iterMax) { // find the residual - blas.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1) - blas.daxpy(n, -1.0, atb, 1, res, 1) - blas.dcopy(n, res, 1, grad, 1) + BLAS.nativeBLAS.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1) + BLAS.nativeBLAS.daxpy(n, -1.0, atb, 1, res, 1) + BLAS.nativeBLAS.dcopy(n, res, 1, grad, 1) // project the gradient i = 0 @@ -115,28 +115,28 @@ private[spark] object NNLS { } i = i + 1 } - val ngrad = blas.ddot(n, grad, 1, grad, 1) + val ngrad = BLAS.nativeBLAS.ddot(n, grad, 1, grad, 1) - blas.dcopy(n, grad, 1, dir, 1) + BLAS.nativeBLAS.dcopy(n, grad, 1, dir, 1) // use a CG direction under certain conditions var step = steplen(grad, res) var ndir = 0.0 - val nx = blas.ddot(n, x, 1, x, 1) + val nx = BLAS.nativeBLAS.ddot(n, x, 1, x, 1) if (iterno > lastWall + 1) { val alpha = ngrad / lastNorm - blas.daxpy(n, alpha, lastDir, 1, dir, 1) + BLAS.nativeBLAS.daxpy(n, alpha, lastDir, 1, dir, 1) val dstep = steplen(dir, res) - ndir = blas.ddot(n, dir, 1, dir, 1) + ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1) if (stop(dstep, ndir, nx)) { // reject the CG step if it could lead to premature termination - blas.dcopy(n, grad, 1, 
dir, 1) - ndir = blas.ddot(n, dir, 1, dir, 1) + BLAS.nativeBLAS.dcopy(n, grad, 1, dir, 1) + ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1) } else { step = dstep } } else { - ndir = blas.ddot(n, dir, 1, dir, 1) + ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1) } // terminate? @@ -166,7 +166,7 @@ private[spark] object NNLS { } iterno = iterno + 1 - blas.dcopy(n, dir, 1, lastDir, 1) + BLAS.nativeBLAS.dcopy(n, dir, 1, lastDir, 1) lastNorm = ngrad } x.clone diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index b1be5225ce51f..3276513213f5d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -21,7 +21,6 @@ import java.io.IOException import java.lang.{Integer => JavaInteger} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.google.common.collect.{Ordering => GuavaOrdering} import org.apache.hadoop.fs.Path import org.json4s._ @@ -85,7 +84,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( val userVector = userFeatureSeq.head val productVector = productFeatureSeq.head - blas.ddot(rank, userVector, 1, productVector, 1) + BLAS.nativeBLAS.ddot(rank, userVector, 1, productVector, 1) } /** @@ -136,7 +135,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( } users.join(productFeatures).map { case (product, ((user, uFeatures), pFeatures)) => - Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + Rating(user, product, BLAS.nativeBLAS.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) } } else { val products = productFeatures.join(usersProducts.map(_.swap)).map { @@ -144,7 +143,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( } products.join(userFeatures).map { case 
(user, ((product, pFeatures), uFeatures)) => - Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + Rating(user, product, BLAS.nativeBLAS.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) } } } @@ -263,7 +262,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { recommendableFeatures: RDD[(Int, Array[Double])], num: Int): Array[(Int, Double)] = { val scored = recommendableFeatures.map { case (id, features) => - (id, blas.ddot(features.length, recommendToFeatures, 1, features, 1)) + (id, BLAS.nativeBLAS.ddot(features.length, recommendToFeatures, 1, features, 1)) } scored.top(num)(Ordering.by(_._2)) } @@ -320,7 +319,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { Iterator.range(0, m).flatMap { i => // scores = i-th vec in srcMat * dstMat - BLAS.f2jBLAS.dgemv("T", rank, n, 1.0F, dstMat, 0, rank, + BLAS.javaBLAS.dgemv("T", rank, n, 1.0F, dstMat, 0, rank, srcMat, i * rank, 1, 0.0F, scores, 0, 1) val srcId = srcIds(i) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index f253963270bc4..f0236f0528a21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.stat -import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.rdd.RDD /** @@ -99,10 +98,10 @@ class KernelDensity extends Serializable { (x._1, x._2 + 1) }, (x, y) => { - blas.daxpy(n, 1.0, y._1, 1, x._1, 1) + BLAS.nativeBLAS.daxpy(n, 1.0, y._1, 1, x._1, 1) (x._1, x._2 + y._2) }) - blas.dscal(n, 1.0 / count, densities, 1) + BLAS.nativeBLAS.dscal(n, 1.0 / count, densities, 1) densities } } diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index c5069277fad68..1f879a4d9dfbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -28,6 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo @@ -280,7 +280,7 @@ private[tree] sealed class TreeEnsembleModel( */ private def predictBySumming(features: Vector): Double = { val treePredictions = trees.map(_.predict(features)) - blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) + BLAS.nativeBLAS.ddot(numTrees, treePredictions, 1, treeWeights, 1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index 9fffa508afbfb..0f99cef665eaf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -19,10 +19,9 @@ package org.apache.spark.mllib.util import scala.util.Random -import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.mllib.linalg.Vectors import 
org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -61,7 +60,8 @@ object SVMDataGenerator { val x = Array.fill[Double](nfeatures) { rnd.nextDouble() * 2.0 - 1.0 } - val yD = blas.ddot(trueWeights.length, x, 1, trueWeights, 1) + rnd.nextGaussian() * 0.1 + val yD = BLAS.nativeBLAS.ddot(trueWeights.length, x, 1, trueWeights, 1) + + rnd.nextGaussian() * 0.1 val y = if (yD < 0) 0.0 else 1.0 LabeledPoint(y, Vectors.dense(x)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 12ab2ac3cc698..91d1e9a44791e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.util.TestingUtils._ class BLASSuite extends SparkFunSuite { test("nativeL1Threshold") { - assert(getBLAS(128) == BLAS.f2jBLAS) + assert(getBLAS(128) == BLAS.javaBLAS) assert(getBLAS(256) == BLAS.nativeBLAS) assert(getBLAS(512) == BLAS.nativeBLAS) } diff --git a/pom.xml b/pom.xml index 26f61ab5b2923..9402fd45284ad 100644 --- a/pom.xml +++ b/pom.xml @@ -172,6 +172,7 @@ 2.12.2 1.1.8.2 1.1.2 + 1.3.2 1.15 1.20 2.8.0 @@ -2455,6 +2456,21 @@ commons-cli ${commons-cli.version} + + dev.ludovic.netlib + blas + ${netlib.ludovic.dev.version} + + + dev.ludovic.netlib + lapack + ${netlib.ludovic.dev.version} + + + dev.ludovic.netlib + arpack + ${netlib.ludovic.dev.version} + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b872668db7658..906065ca09048 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -294,7 +294,7 @@ object SparkBuild extends PomBuild { javaOptions ++= { val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3) var major = versionParts(0).toInt - if (major >= 16) Seq("--add-modules=jdk.incubator.vector") else Seq.empty + if (major >= 16) 
Seq("--add-modules=jdk.incubator.vector,jdk.incubator.foreign", "-Dforeign.restricted=warn") else Seq.empty }, (Compile / doc / javacOptions) ++= { diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 28c4499f779ec..5bc1801a0c957 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -260,9 +260,9 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable): >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, newPrediction=0.692910...) + Row(user=0, item=2, newPrediction=0.69291...) >>> predictions[1] - Row(user=1, item=0, newPrediction=3.473569...) + Row(user=1, item=0, newPrediction=3.47356...) >>> predictions[2] Row(user=2, item=0, newPrediction=-0.899198...) >>> user_recs = model.recommendForAllUsers(3)