Skip to content

Commit

Permalink
[SPARK-35150][ML] Accelerate fallback BLAS with dev.ludovic.netlib
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Following #30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package.

The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focus on accelerating the linear algebra routines in use in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation.

Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, as well as takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available.

A table summarising which version gets loaded in which case:

```
|                       | BLAS.nativeBLAS                                    | BLAS.javaBLAS                                      |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| with -Pnetlib-lgpl    | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a     | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |     wrapper for com.github.fommil:all              |    (JDK16+, relies on the Vector API, requires     |
|                       | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    |     `--add-modules=jdk.incubator.vector` on JDK16) |
|                       |    relies on the Foreign Linker API, requires      | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       |    `--add-modules=jdk.incubator.foreign            | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |     -Dforeign.restricted=warn`)                    | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       | 3. fails to load, falls back to BLAS.javaBLAS in   |     wrapper for com.github.fommil:core             |
|                       |     org.apache.spark.ml.linalg.BLAS                |                                                    |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |    relies on the Foreign Linker API, requires      |    (JDK16+, relies on the Vector API, requires     |
|                       |    `--add-modules=jdk.incubator.foreign            |     `--add-modules=jdk.incubator.vector` on JDK16) |
|                       |     -Dforeign.restricted=warn`)                    | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       | 2. fails to load, falls back to BLAS.javaBLAS in   | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |     org.apache.spark.ml.linalg.BLAS                | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       |                                                    |     wrapper for com.github.fommil:core             |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
```

### Why are the changes needed?

Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available.

### Does this PR introduce _any_ user-facing change?

No, all changes are transparent to the user.

### How was this patch tested?

The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite.

[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`:
#### JDK8:
```
[info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.Java8BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 223            232           8        448.0           2.2       1.0X
[info] java                                                221            228           7        453.0           2.2       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 122            128           4        821.2           1.2       1.0X
[info] java                                                122            128           4        822.3           1.2       1.0X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 109            112           2        921.4           1.1       1.0X
[info] java                                                 70             74           3       1423.5           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             98           2       1046.1           1.0       1.0X
[info] java                                                 47             49           2       2121.7           0.5       2.0X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 184            195           8        544.3           1.8       1.0X
[info] java                                                185            196           7        539.5           1.9       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  99            104           4       1011.9           1.0       1.0X
[info] java                                                 99            104           4       1010.4           1.0       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        947.2           1.1       1.0X
[info] java                                                  0              0           0       1584.8           0.6       1.7X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        867.4           1.2       1.0X
[info] java                                                  1              1           0        865.0           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        485.9           2.1       1.0X
[info] java                                                  1              1           0        486.8           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1843.0           0.5       1.0X
[info] java                                                  0              0           0       2690.6           0.4       1.5X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1214.7           0.8       1.0X
[info] java                                                  0              0           0       2536.8           0.4       2.1X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1895.9           0.5       1.0X
[info] java                                                  0              0           0       2961.1           0.3       1.6X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1223.4           0.8       1.0X
[info] java                                                  0              0           0       3091.4           0.3       2.5X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 560            575          20       1787.1           0.6       1.0X
[info] java                                                226            232           5       4432.4           0.2       2.5X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 570            586          23       1755.2           0.6       1.0X
[info] java                                                227            232           4       4410.1           0.2       2.5X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 863            879          17       1158.4           0.9       1.0X
[info] java                                                227            231           3       4407.9           0.2       3.8X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1282           1305          23        780.0           1.3       1.0X
[info] java                                                227            232           4       4413.4           0.2       5.7X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 538            548           8       1858.6           0.5       1.0X
[info] java                                                221            226           3       4521.1           0.2       2.4X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 549            558          10       1819.9           0.5       1.0X
[info] java                                                222            229           7       4503.5           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 838            852          12       1193.0           0.8       1.0X
[info] java                                                222            229           5       4500.5           0.2       3.8X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 905            919          18       1104.8           0.9       1.0X
[info] java                                                221            228           5       4521.3           0.2       4.1X
```

#### JDK11:
```
[info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.Java11BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 195            204          10        512.7           2.0       1.0X
[info] java                                                195            202           7        512.4           2.0       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 108            113           4        923.3           1.1       1.0X
[info] java                                                102            107           4        984.4           1.0       1.1X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 107            110           3        938.1           1.1       1.0X
[info] java                                                 69             72           3       1447.1           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             98           2       1046.5           1.0       1.0X
[info] java                                                 43             45           2       2317.1           0.4       2.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 155            168           8        644.2           1.6       1.0X
[info] java                                                158            169           8        632.8           1.6       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  85             90           4       1178.1           0.8       1.0X
[info] java                                                 86             90           4       1167.7           0.9       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       1182.1           0.8       1.0X
[info] java                                                  0              0           0       1432.1           0.7       1.2X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        898.7           1.1       1.0X
[info] java                                                  1              1           0        891.5           1.1       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        495.4           2.0       1.0X
[info] java                                                  1              1           0        495.7           2.0       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2271.6           0.4       1.0X
[info] java                                                  0              0           0       3648.1           0.3       1.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1229.3           0.8       1.0X
[info] java                                                  0              0           0       2711.3           0.4       2.2X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2677.5           0.4       1.0X
[info] java                                                  0              0           0       3288.2           0.3       1.2X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1233.0           0.8       1.0X
[info] java                                                  0              0           0       2766.3           0.4       2.2X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 520            536          16       1923.6           0.5       1.0X
[info] java                                                214            221           7       4669.5           0.2       2.4X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 593            612          17       1686.5           0.6       1.0X
[info] java                                                215            219           3       4643.3           0.2       2.8X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 853            870          16       1172.8           0.9       1.0X
[info] java                                                215            218           3       4659.7           0.2       4.0X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1350           1370          23        740.8           1.3       1.0X
[info] java                                                215            219           4       4656.6           0.2       6.3X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 460            468           6       2173.2           0.5       1.0X
[info] java                                                210            213           2       4752.7           0.2       2.2X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 535            544           8       1869.3           0.5       1.0X
[info] java                                                210            215           5       4761.8           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 843            853          11       1186.8           0.8       1.0X
[info] java                                                209            214           4       4793.4           0.2       4.0X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 891            904          15       1122.0           0.9       1.0X
[info] java                                                209            214           4       4777.2           0.2       4.3X
```

#### JDK16:
```
[info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU  3.80GHz
[info]
[info] f2jBLAS    = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS   = dev.ludovic.netlib.blas.VectorizedBLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 194            199           7        515.7           1.9       1.0X
[info] java                                                181            186           3        551.1           1.8       1.1X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 109            115           4        915.0           1.1       1.0X
[info] java                                                 88             92           3       1138.8           0.9       1.2X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 108            110           2        922.6           1.1       1.0X
[info] java                                                 54             56           2       1839.2           0.5       2.0X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  96             97           2       1046.1           1.0       1.0X
[info] java                                                 29             30           1       3393.4           0.3       3.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 156            165           5        643.0           1.6       1.0X
[info] java                                                150            159           5        667.1           1.5       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  85             91           6       1171.0           0.9       1.0X
[info] java                                                 75             79           3       1340.6           0.7       1.1X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        917.0           1.1       1.0X
[info] java                                                  0              0           0       8147.2           0.1       8.9X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        859.3           1.2       1.0X
[info] java                                                  1              1           0        859.3           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0        482.1           2.1       1.0X
[info] java                                                  1              1           0        482.6           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2214.2           0.5       1.0X
[info] java                                                  0              0           0       7975.8           0.1       3.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1231.4           0.8       1.0X
[info] java                                                  0              0           0       8680.9           0.1       7.0X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   0              0           0       2684.3           0.4       1.0X
[info] java                                                  0              0           0      18527.1           0.1       6.9X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                   1              1           0       1235.4           0.8       1.0X
[info] java                                                  0              0           0      17347.9           0.1      14.0X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 530            552          18       1887.5           0.5       1.0X
[info] java                                                 58             64           3      17143.9           0.1       9.1X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 598            620          17       1671.1           0.6       1.0X
[info] java                                                 58             64           3      17196.6           0.1      10.3X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 834            847          14       1199.4           0.8       1.0X
[info] java                                                 57             63           4      17486.9           0.1      14.6X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                1338           1366          22        747.3           1.3       1.0X
[info] java                                                 58             63           3      17356.6           0.1      23.2X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 489            501           9       2045.5           0.5       1.0X
[info] java                                                 36             38           2      27721.9           0.0      13.6X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 478            488           9       2094.0           0.5       1.0X
[info] java                                                 36             38           2      27813.2           0.0      13.3X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 825            837          10       1211.6           0.8       1.0X
[info] java                                                 35             38           2      28433.1           0.0      23.5X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 900            918          15       1111.6           0.9       1.0X
[info] java                                                 36             38           2      28073.0           0.0      25.3X
```

[2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas

Closes #32253 from luhenry/master.

Authored-by: Ludovic Henry <git@ludovic.dev>
Signed-off-by: Sean Owen <srowen@gmail.com>
  • Loading branch information
luhenry authored and srowen committed Apr 27, 2021
1 parent 26a8d2f commit 5b77ebb
Show file tree
Hide file tree
Showing 27 changed files with 627 additions and 793 deletions.
3 changes: 3 additions & 0 deletions dev/deps/spark-deps-hadoop-2.7-hive-2.3
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ apacheds-i18n/2.0.0-M15//apacheds-i18n-2.0.0-M15.jar
apacheds-kerberos-codec/2.0.0-M15//apacheds-kerberos-codec-2.0.0-M15.jar
api-asn1-api/1.0.0-M20//api-asn1-api-1.0.0-M20.jar
api-util/1.0.0-M20//api-util-1.0.0-M20.jar
arpack/1.3.2//arpack-1.3.2.jar
arpack_combined_all/0.1//arpack_combined_all-0.1.jar
arrow-format/2.0.0//arrow-format-2.0.0.jar
arrow-memory-core/2.0.0//arrow-memory-core-2.0.0.jar
Expand All @@ -25,6 +26,7 @@ automaton/1.11-8//automaton-1.11-8.jar
avro-ipc/1.10.2//avro-ipc-1.10.2.jar
avro-mapred/1.10.2//avro-mapred-1.10.2.jar
avro/1.10.2//avro-1.10.2.jar
blas/1.3.2//blas-1.3.2.jar
bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar
breeze-macros_2.12/1.0//breeze-macros_2.12-1.0.jar
breeze_2.12/1.0//breeze_2.12-1.0.jar
Expand Down Expand Up @@ -173,6 +175,7 @@ kubernetes-model-policy/5.3.0//kubernetes-model-policy-5.3.0.jar
kubernetes-model-rbac/5.3.0//kubernetes-model-rbac-5.3.0.jar
kubernetes-model-scheduling/5.3.0//kubernetes-model-scheduling-5.3.0.jar
kubernetes-model-storageclass/5.3.0//kubernetes-model-storageclass-5.3.0.jar
lapack/1.3.2//lapack-1.3.2.jar
leveldbjni-all/1.8//leveldbjni-all-1.8.jar
libfb303/0.9.3//libfb303-0.9.3.jar
libthrift/0.12.0//libthrift-0.12.0.jar
Expand Down
3 changes: 3 additions & 0 deletions dev/deps/spark-deps-hadoop-3.2-hive-2.3
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ annotations/17.0.0//annotations-17.0.0.jar
antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar
antlr4-runtime/4.8-1//antlr4-runtime-4.8-1.jar
aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar
arpack/1.3.2//arpack-1.3.2.jar
arpack_combined_all/0.1//arpack_combined_all-0.1.jar
arrow-format/2.0.0//arrow-format-2.0.0.jar
arrow-memory-core/2.0.0//arrow-memory-core-2.0.0.jar
Expand All @@ -20,6 +21,7 @@ automaton/1.11-8//automaton-1.11-8.jar
avro-ipc/1.10.2//avro-ipc-1.10.2.jar
avro-mapred/1.10.2//avro-mapred-1.10.2.jar
avro/1.10.2//avro-1.10.2.jar
blas/1.3.2//blas-1.3.2.jar
bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar
breeze-macros_2.12/1.0//breeze-macros_2.12-1.0.jar
breeze_2.12/1.0//breeze_2.12-1.0.jar
Expand Down Expand Up @@ -144,6 +146,7 @@ kubernetes-model-policy/5.3.0//kubernetes-model-policy-5.3.0.jar
kubernetes-model-rbac/5.3.0//kubernetes-model-rbac-5.3.0.jar
kubernetes-model-scheduling/5.3.0//kubernetes-model-scheduling-5.3.0.jar
kubernetes-model-storageclass/5.3.0//kubernetes-model-storageclass-5.3.0.jar
lapack/1.3.2//lapack-1.3.2.jar
leveldbjni-all/1.8//leveldbjni-all-1.8.jar
libfb303/0.9.3//libfb303-0.9.3.jar
libthrift/0.12.0//libthrift-0.12.0.jar
Expand Down
2 changes: 1 addition & 1 deletion docs/ml-linalg-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ WARN BLAS: Failed to load implementation from:com.github.fommil.netlib.NativeSys
WARN BLAS: Failed to load implementation from:com.github.fommil.netlib.NativeRefBLAS
```

If native libraries are not properly configured in the system, the Java implementation (f2jBLAS) will be used as fallback option.
If native libraries are not properly configured in the system, the Java implementation (javaBLAS) will be used as fallback option.

## Spark Configuration

Expand Down
5 changes: 2 additions & 3 deletions graphx/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@
<artifactId>guava</artifactId>
</dependency>
<dependency>
<groupId>com.github.fommil.netlib</groupId>
<artifactId>core</artifactId>
<version>${netlib.java.version}</version>
<groupId>dev.ludovic.netlib</groupId>
<artifactId>blas</artifactId>
</dependency>
<dependency>
<groupId>net.sourceforge.f2j</groupId>
Expand Down
31 changes: 15 additions & 16 deletions graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ package org.apache.spark.graphx.lib

import scala.util.Random

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.graphx._
import org.apache.spark.ml.linalg.BLAS
import org.apache.spark.rdd._

/** Implementation of SVD++ algorithm. */
Expand Down Expand Up @@ -102,22 +101,22 @@ object SVDPlusPlus {
val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
val (p, q) = (usr._1, itm._1)
val rank = p.length
var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1)
var pred = u + usr._3 + itm._3 + BLAS.nativeBLAS.ddot(rank, q, 1, usr._2, 1)
pred = math.max(pred, conf.minVal)
pred = math.min(pred, conf.maxVal)
val err = ctx.attr - pred
// updateP = (err * q - conf.gamma7 * p) * conf.gamma2
val updateP = q.clone()
blas.dscal(rank, err * conf.gamma2, updateP, 1)
blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1)
BLAS.nativeBLAS.dscal(rank, err * conf.gamma2, updateP, 1)
BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1)
// updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2
val updateQ = usr._2.clone()
blas.dscal(rank, err * conf.gamma2, updateQ, 1)
blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1)
BLAS.nativeBLAS.dscal(rank, err * conf.gamma2, updateQ, 1)
BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1)
// updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2
val updateY = q.clone()
blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1)
blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1)
BLAS.nativeBLAS.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1)
BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1)
ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1))
ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))
}
Expand All @@ -129,15 +128,15 @@ object SVDPlusPlus {
ctx => ctx.sendToSrc(ctx.dstAttr._2),
(g1, g2) => {
val out = g1.clone()
blas.daxpy(out.length, 1.0, g2, 1, out, 1)
BLAS.nativeBLAS.daxpy(out.length, 1.0, g2, 1, out, 1)
out
})
val gJoinT1 = g.outerJoinVertices(t1) {
(vid: VertexId, vd: (Array[Double], Array[Double], Double, Double),
msg: Option[Array[Double]]) =>
if (msg.isDefined) {
val out = vd._1.clone()
blas.daxpy(out.length, vd._4, msg.get, 1, out, 1)
BLAS.nativeBLAS.daxpy(out.length, vd._4, msg.get, 1, out, 1)
(vd._1, out, vd._3, vd._4)
} else {
vd
Expand All @@ -154,19 +153,19 @@ object SVDPlusPlus {
(g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) =>
{
val out1 = g1._1.clone()
blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1)
BLAS.nativeBLAS.daxpy(out1.length, 1.0, g2._1, 1, out1, 1)
val out2 = g2._2.clone()
blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1)
BLAS.nativeBLAS.daxpy(out2.length, 1.0, g2._2, 1, out2, 1)
(out1, out2, g1._3 + g2._3)
})
val gJoinT2 = g.outerJoinVertices(t2) {
(vid: VertexId,
vd: (Array[Double], Array[Double], Double, Double),
msg: Option[(Array[Double], Array[Double], Double)]) => {
val out1 = vd._1.clone()
blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1)
BLAS.nativeBLAS.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1)
val out2 = vd._2.clone()
blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1)
BLAS.nativeBLAS.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1)
(out1, out2, vd._3 + msg.get._3, vd._4)
}
}.cache()
Expand All @@ -180,7 +179,7 @@ object SVDPlusPlus {
(ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]): Unit = {
val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
val (p, q) = (usr._1, itm._1)
var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1)
var pred = u + usr._3 + itm._3 + BLAS.nativeBLAS.ddot(q.length, q, 1, usr._2, 1)
pred = math.max(pred, conf.minVal)
pred = math.min(pred, conf.maxVal)
val err = (ctx.attr - pred) * (ctx.attr - pred)
Expand Down
25 changes: 25 additions & 0 deletions licenses-binary/LICENSE-blas.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
MIT License
-----------

Copyright 2020, 2021, Ludovic Henry

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Please contact git@ludovic.dev or visit ludovic.dev if you need additional
information or have any questions.
33 changes: 5 additions & 28 deletions mllib-local/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.ludovic.netlib</groupId>
<artifactId>blas</artifactId>
</dependency>
</dependencies>
<profiles>
<profile>
Expand All @@ -88,34 +93,6 @@
</dependency>
</dependencies>
</profile>
<profile>
<id>jvm-vectorized</id>
<properties>
<extra.source.dir>src/jvm-vectorized/java</extra.source.dir>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-vectorized-sources</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${extra.source.dir}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
Expand Down
Loading

0 comments on commit 5b77ebb

Please sign in to comment.