Skip to content

Commit

Permalink
[SPARK-33882][ML] Add a vectorized BLAS implementation
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This patch introduces a VectorizedBLAS class which implements such hardware-accelerated BLAS operations. This feature is hidden behind the "vectorized" profile that you can enable by passing "-Pvectorized" to sbt or maven.

The Vector API has been introduced in JDK 16. Following discussion on the mailing list, this API is introduced transparently and needs to be enabled explicitely.

### Why are the changes needed?

Whenever a native BLAS implementation isn't available on the system, Spark automatically falls back onto a Java implementation. With the recent release of the Vector API in the OpenJDK [1], we can use hardware acceleration for such operations.

This change was also discussed on the mailing list. [2]

### Does this PR introduce _any_ user-facing change?

It introduces a build-time profile called `vectorized`. You can pass it to sbt and mvn with `-Pvectorized`. There is no change to the end-user of Spark and it should only impact Spark developpers. It is also disabled by default.

### How was this patch tested?

It passes `build/sbt mllib-local/test` with and without `-Pvectorized` with JDK 16. This patch also introduces benchmarks for BLAS.

The benchmark results are as follows:

```
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  37             37           0        271.5           3.7       1.0X
[info] vector                                               24             25           4        416.1           2.4       1.5X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  70             70           0        143.2           7.0       1.0X
[info] vector                                               35             35           2        288.7           3.5       2.0X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  50             51           1        199.8           5.0       1.0X
[info] vector                                               15             15           0        648.7           1.5       3.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  34             34           0        295.6           3.4       1.0X
[info] vector                                               19             19           0        531.2           1.9       1.8X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  25             25           1        399.0           2.5       1.0X
[info] vector                                                8              9           1       1177.3           0.8       3.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  27             27           0          0.0       26651.5       1.0X
[info] vector                                               21             21           0          0.0       20646.3       1.3X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  36             36           0          0.0       35501.4       1.0X
[info] vector                                               22             22           0          0.0       21930.3       1.6X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  20             20           0          0.0       20283.3       1.0X
[info] vector                                                9              9           0          0.1        8657.7       2.3X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  30             30           0          0.0       29845.8       1.0X
[info] vector                                               10             10           1          0.1        9695.4       3.1X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 182            182           0          0.5        1820.0       1.0X
[info] vector                                              160            160           1          0.6        1597.6       1.1X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 211            211           1          0.5        2106.2       1.0X
[info] vector                                              156            157           0          0.6        1564.4       1.3X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 276            276           0          0.4        2757.8       1.0X
[info] vector                                              137            137           0          0.7        1365.1       2.0X
```

/cc srowen xkrogen

[1] https://openjdk.java.net/jeps/338
[2] https://mail-archives.apache.org/mod_mbox/spark-dev/202012.mbox/%3cDM5PR2101MB11106162BB3AF32AD29C6C79B0C69DM5PR2101MB1110.namprd21.prod.outlook.com%3e

Closes #30810 from luhenry/master.

Lead-authored-by: Ludovic Henry <luhenry@microsoft.com>
Co-authored-by: Ludovic Henry <git@ludovic.dev>
Signed-off-by: Sean Owen <srowen@gmail.com>
  • Loading branch information
2 people authored and srowen committed Apr 14, 2021
1 parent de9e8b6 commit 9244066
Show file tree
Hide file tree
Showing 15 changed files with 1,085 additions and 52 deletions.
35 changes: 35 additions & 0 deletions mllib-local/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
Expand All @@ -81,6 +88,34 @@
</dependency>
</dependencies>
</profile>
<profile>
<id>jvm-vectorized</id>
<properties>
<extra.source.dir>src/jvm-vectorized/java</extra.source.dir>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-vectorized-sources</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${extra.source.dir}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
Expand Down
Loading

0 comments on commit 9244066

Please sign in to comment.