diff --git a/.github/workflows/dev_cron.yml b/.github/workflows/dev_cron.yml index ab2840eb..e66d4f42 100644 --- a/.github/workflows/dev_cron.yml +++ b/.github/workflows/dev_cron.yml @@ -25,19 +25,21 @@ on: - edited - synchronize +permissions: read-all + jobs: process: name: Process runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0 - name: Comment Issues link if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'edited') - uses: actions/github-script@v3 + uses: actions/github-script@ffc2c79a5b2490bd33e0a41c1de74b877714d736 # v3.2.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -49,7 +51,7 @@ jobs: github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'edited') - uses: actions/github-script@v3 + uses: actions/github-script@ffc2c79a5b2490bd33e0a41c1de74b877714d736 # v3.2.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 86b48d45..7feb16f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,343 @@ # Change log -Generated on 2021-04-29 +Generated on 2021-09-02 + +## Release 1.2.0 + +### Gazelle Plugin + +#### Features +||| +|:---|:---| +|[#394](https://github.com/oap-project/gazelle_plugin/issues/394)|Support ColumnarArrowEvalPython operator | +|[#368](https://github.com/oap-project/gazelle_plugin/issues/368)|Encountered Hadoop version (3.2.1) conflict issue on AWS EMR-6.3.0| +|[#375](https://github.com/oap-project/gazelle_plugin/issues/375)|Implement a series of datetime functions| +|[#183](https://github.com/oap-project/gazelle_plugin/issues/183)|Add Date/Timestamp type support| +|[#362](https://github.com/oap-project/gazelle_plugin/issues/362)|make arrow-unsafe allocator as the default| +|[#343](https://github.com/oap-project/gazelle_plugin/issues/343)|configurable codegen opt level| +|[#333](https://github.com/oap-project/gazelle_plugin/issues/333)|Arrow Data Source: CSV format support fix| +|[#223](https://github.com/oap-project/gazelle_plugin/issues/223)|Add Parquet write support to Arrow data source| +|[#320](https://github.com/oap-project/gazelle_plugin/issues/320)|Add build option to enable unsafe Arrow allocator| +|[#337](https://github.com/oap-project/gazelle_plugin/issues/337)|UDF: Add test case for validating basic row-based udf| +|[#326](https://github.com/oap-project/gazelle_plugin/issues/326)|Update Scala unit test to spark-3.1.1| + +#### Performance +||| +|:---|:---| +|[#400](https://github.com/oap-project/gazelle_plugin/issues/400)|Optimize ColumnarToRow Operator in NSE.| +|[#411](https://github.com/oap-project/gazelle_plugin/issues/411)|enable ccache on C++ code compiling| + +#### Bugs Fixed +||| +|:---|:---| +|[#358](https://github.com/oap-project/gazelle_plugin/issues/358)|Running TPC DS all queries with native-sql-engine for 10 rounds will have performance degradation problems in the last few rounds| +|[#481](https://github.com/oap-project/gazelle_plugin/issues/481)|JVM heap memory leak on memory leak tracker facilities| +|[#436](https://github.com/oap-project/gazelle_plugin/issues/436)|Fix for Arrow Data Source test suite| +|[#317](https://github.com/oap-project/gazelle_plugin/issues/317)|persistent memory cache issue| +|[#382](https://github.com/oap-project/gazelle_plugin/issues/382)|Hadoop version conflict when supporting to use gazelle_plugin on Google Cloud Dataproc| 
+|[#384](https://github.com/oap-project/gazelle_plugin/issues/384)|ColumnarBatchScanExec reading parquet failed on java.lang.IllegalArgumentException: not all nodes and buffers were consumed| +|[#370](https://github.com/oap-project/gazelle_plugin/issues/370)|Failed to get time zone: NoSuchElementException: None.get| +|[#360](https://github.com/oap-project/gazelle_plugin/issues/360)|Cannot compile master branch.| +|[#341](https://github.com/oap-project/gazelle_plugin/issues/341)|build failed on v2 with -Phadoop-3.2| + +#### PRs +||| +|:---|:---| +|[#489](https://github.com/oap-project/gazelle_plugin/pull/489)|[NSE-481] JVM heap memory leak on memory leak tracker facilities (Arrow Allocator)| +|[#486](https://github.com/oap-project/gazelle_plugin/pull/486)|[NSE-475] restore coalescebatches operator before window| +|[#482](https://github.com/oap-project/gazelle_plugin/pull/482)|[NSE-481] JVM heap memory leak on memory leak tracker facilities| +|[#470](https://github.com/oap-project/gazelle_plugin/pull/470)|[NSE-469] Lazy Read: Iterator objects are not correctly released| +|[#464](https://github.com/oap-project/gazelle_plugin/pull/464)|[NSE-460] fix decimal partial sum in 1.2 branch| +|[#439](https://github.com/oap-project/gazelle_plugin/pull/439)|[NSE-433]Support pre-built Jemalloc| +|[#453](https://github.com/oap-project/gazelle_plugin/pull/453)|[NSE-254] remove arrow-data-source-common from jar with dependency| +|[#452](https://github.com/oap-project/gazelle_plugin/pull/452)|[NSE-254]Fix redundant arrow library issue.| +|[#432](https://github.com/oap-project/gazelle_plugin/pull/432)|[NSE-429] TPC-DS Q14a/b get slowed down within setting spark.oap.sql.columnar.sortmergejoin.lazyread=true| +|[#426](https://github.com/oap-project/gazelle_plugin/pull/426)|[NSE-207] Fix aggregate and refresh UT test script| +|[#442](https://github.com/oap-project/gazelle_plugin/pull/442)|[NSE-254]Issue0410 jar size| +|[#441](https://github.com/oap-project/gazelle_plugin/pull/441)|[NSE-254]Issue0410 jar size| +|[#440](https://github.com/oap-project/gazelle_plugin/pull/440)|[NSE-254]Solve the redundant arrow library issue| +|[#437](https://github.com/oap-project/gazelle_plugin/pull/437)|[NSE-436] Fix for Arrow Data Source test suite| +|[#387](https://github.com/oap-project/gazelle_plugin/pull/387)|[NSE-383] Release SMJ input data immediately after being used| +|[#423](https://github.com/oap-project/gazelle_plugin/pull/423)|[NSE-417] fix sort spill on inplsace sort| +|[#416](https://github.com/oap-project/gazelle_plugin/pull/416)|[NSE-207] fix left/right outer join in SMJ| +|[#422](https://github.com/oap-project/gazelle_plugin/pull/422)|[NSE-421]Disable the wholestagecodegen feature for the ArrowColumnarToRow operator| +|[#369](https://github.com/oap-project/gazelle_plugin/pull/369)|[NSE-417] Sort spill support framework| +|[#401](https://github.com/oap-project/gazelle_plugin/pull/401)|[NSE-400] Optimize ColumnarToRow Operator in NSE.| +|[#413](https://github.com/oap-project/gazelle_plugin/pull/413)|[NSE-411] adding ccache support| +|[#393](https://github.com/oap-project/gazelle_plugin/pull/393)|[NSE-207] fix scala unit tests| +|[#407](https://github.com/oap-project/gazelle_plugin/pull/407)|[NSE-403]Add Dataproc integration section to README| +|[#406](https://github.com/oap-project/gazelle_plugin/pull/406)|[NSE-404]Modify repo name in documents| +|[#402](https://github.com/oap-project/gazelle_plugin/pull/402)|[NSE-368]Update emr-6.3.0 support| 
+|[#395](https://github.com/oap-project/gazelle_plugin/pull/395)|[NSE-394]Support ColumnarArrowEvalPython operator| +|[#346](https://github.com/oap-project/gazelle_plugin/pull/346)|[NSE-317]fix columnar cache| +|[#392](https://github.com/oap-project/gazelle_plugin/pull/392)|[NSE-382]Support GCP Dataproc 2.0| +|[#388](https://github.com/oap-project/gazelle_plugin/pull/388)|[NSE-382]Fix Hadoop version issue| +|[#385](https://github.com/oap-project/gazelle_plugin/pull/385)|[NSE-384] "Select count(*)" without group by results in error: java.lang.IllegalArgumentException: not all nodes and buffers were consumed| +|[#374](https://github.com/oap-project/gazelle_plugin/pull/374)|[NSE-207] fix left anti join and support filter wo/ project| +|[#376](https://github.com/oap-project/gazelle_plugin/pull/376)|[NSE-375] Implement a series of datetime functions| +|[#373](https://github.com/oap-project/gazelle_plugin/pull/373)|[NSE-183] fix timestamp in native side| +|[#356](https://github.com/oap-project/gazelle_plugin/pull/356)|[NSE-207] fix issues found in scala unit tests| +|[#371](https://github.com/oap-project/gazelle_plugin/pull/371)|[NSE-370] Failed to get time zone: NoSuchElementException: None.get| +|[#347](https://github.com/oap-project/gazelle_plugin/pull/347)|[NSE-183] Add Date/Timestamp type support| +|[#363](https://github.com/oap-project/gazelle_plugin/pull/363)|[NSE-362] use arrow-unsafe allocator by default| +|[#361](https://github.com/oap-project/gazelle_plugin/pull/361)|[NSE-273] Spark shim layer infrastructure| +|[#364](https://github.com/oap-project/gazelle_plugin/pull/364)|[NSE-360] fix ut compile and travis test| +|[#264](https://github.com/oap-project/gazelle_plugin/pull/264)|[NSE-207] fix issues found from join unit tests| +|[#344](https://github.com/oap-project/gazelle_plugin/pull/344)|[NSE-343]allow to config codegen opt level| +|[#342](https://github.com/oap-project/gazelle_plugin/pull/342)|[NSE-341] fix maven build failure| +|[#324](https://github.com/oap-project/gazelle_plugin/pull/324)|[NSE-223] Add Parquet write support to Arrow data source| +|[#321](https://github.com/oap-project/gazelle_plugin/pull/321)|[NSE-320] Add build option to enable unsafe Arrow allocator| +|[#299](https://github.com/oap-project/gazelle_plugin/pull/299)|[NSE-207] fix unsuppored types in aggregate| +|[#338](https://github.com/oap-project/gazelle_plugin/pull/338)|[NSE-337] UDF: Add test case for validating basic row-based udf| +|[#336](https://github.com/oap-project/gazelle_plugin/pull/336)|[NSE-333] Arrow Data Source: CSV format support fix| +|[#327](https://github.com/oap-project/gazelle_plugin/pull/327)|[NSE-326] update scala unit tests to spark-3.1.1| + +### OAP MLlib + +#### Features +||| +|:---|:---| +|[#110](https://github.com/oap-project/oap-mllib/issues/110)|Update isOAPEnabled for Kmeans, PCA & ALS| +|[#108](https://github.com/oap-project/oap-mllib/issues/108)|Update PCA GPU, LiR CPU and Improve JAR packaging and libs loading| +|[#93](https://github.com/oap-project/oap-mllib/issues/93)|[GPU] Add GPU support for PCA| +|[#101](https://github.com/oap-project/oap-mllib/issues/101)|[Release] Add version update scripts and improve scripts for examples| +|[#76](https://github.com/oap-project/oap-mllib/issues/76)|Reorganize Spark version specific code structure| +|[#82](https://github.com/oap-project/oap-mllib/issues/82)|[Tests] Add NaiveBayes test and refactors| + +#### Bugs Fixed +||| +|:---|:---| +|[#119](https://github.com/oap-project/oap-mllib/issues/119)|[SDLe][Klocwork] Security 
vulnerabilities found by static code scan| +|[#121](https://github.com/oap-project/oap-mllib/issues/121)|Meeting freeing memory issue after the training stage when using Intel-MLlib to run PCA and K-means algorithms.| +|[#122](https://github.com/oap-project/oap-mllib/issues/122)|Cannot run K-means and PCA algorithm with oap-mllib on Google Dataproc| +|[#123](https://github.com/oap-project/oap-mllib/issues/123)|[Core] Improve locality handling for native lib loading| +|[#116](https://github.com/oap-project/oap-mllib/issues/116)|Cannot run ALS algorithm with oap-mllib thanks to the commit "2883d3447d07feb55bf5d4fee8225d74b0b1e2b1"| +|[#114](https://github.com/oap-project/oap-mllib/issues/114)|[Core] Improve native lib loading| +|[#94](https://github.com/oap-project/oap-mllib/issues/94)|Failed to run KMeans workload with oap-mllib in JLSE| +|[#95](https://github.com/oap-project/oap-mllib/issues/95)|Some shared libs are missing in 1.1.1 release| +|[#105](https://github.com/oap-project/oap-mllib/issues/105)|[Core] crash when libfabric version conflict| +|[#98](https://github.com/oap-project/oap-mllib/issues/98)|[SDLe][Klocwork] Security vulnerabilities found by static code scan| +|[#88](https://github.com/oap-project/oap-mllib/issues/88)|[Test] Fix ALS Suite "ALS shuffle cleanup standalone"| +|[#86](https://github.com/oap-project/oap-mllib/issues/86)|[NaiveBayes] Fix isOAPEnabled and add multi-version support| + +#### PRs +||| +|:---|:---| +|[#124](https://github.com/oap-project/oap-mllib/pull/124)|[ML-123][Core] Improve locality handling for native lib loading| +|[#118](https://github.com/oap-project/oap-mllib/pull/118)|[ML-116] use getOneCCLIPPort and fix lib loading| +|[#115](https://github.com/oap-project/oap-mllib/pull/115)|[ML-114] [Core] Improve native lib loading| +|[#113](https://github.com/oap-project/oap-mllib/pull/113)|[ML-110] Update isOAPEnabled for Kmeans, PCA & ALS| +|[#112](https://github.com/oap-project/oap-mllib/pull/112)|[ML-105][Core] Fix crash when libfabric version conflict| +|[#111](https://github.com/oap-project/oap-mllib/pull/111)|[ML-108] Update PCA GPU, LiR CPU and Improve JAR packaging and libs loading| +|[#104](https://github.com/oap-project/oap-mllib/pull/104)|[ML-93][GPU] Add GPU support for PCA| +|[#103](https://github.com/oap-project/oap-mllib/pull/103)|[ML-98] [Release] Clean Service.java code| +|[#102](https://github.com/oap-project/oap-mllib/pull/102)|[ML-101] [Release] Add version update scripts and improve scripts for examples| +|[#90](https://github.com/oap-project/oap-mllib/pull/90)|[ML-88][Test] Fix ALS Suite "ALS shuffle cleanup standalone"| +|[#87](https://github.com/oap-project/oap-mllib/pull/87)|[ML-86][NaiveBayes] Fix isOAPEnabled and add multi-version support| +|[#83](https://github.com/oap-project/oap-mllib/pull/83)|[ML-82] [Tests] Add NaiveBayes test and refactors| +|[#75](https://github.com/oap-project/oap-mllib/pull/75)|[ML-53] [CPU] Add Linear & Ridge Regression| +|[#77](https://github.com/oap-project/oap-mllib/pull/77)|[ML-76] Reorganize multiple Spark version support code structure| +|[#68](https://github.com/oap-project/oap-mllib/pull/68)|[ML-55] [CPU] Add Naive Bayes| +|[#64](https://github.com/oap-project/oap-mllib/pull/64)|[ML-42] [PIP] Misc improvements and refactor code| +|[#62](https://github.com/oap-project/oap-mllib/pull/62)|[ML-30][Coding Style] Add code style rules & scripts for Scala, Java and C++| + +### SQL DS Cache + +#### Features +||| +|:---|:---| +|[#155](https://github.com/oap-project/sql-ds-cache/issues/155)|reorg to 
support profile based multi spark version| + +#### Bugs Fixed +||| +|:---|:---| +|[#190](https://github.com/oap-project/sql-ds-cache/issues/190)|The function of vmem-cache and guava-cache should not be associated with arrow.| +|[#181](https://github.com/oap-project/sql-ds-cache/issues/181)|[SDLe]Vulnerabilities scanned by Snyk| + +#### PRs +||| +|:---|:---| +|[#182](https://github.com/oap-project/sql-ds-cache/pull/182)|[SQL-DS-CACHE-181][SDLe]Fix Snyk code scan issues| +|[#191](https://github.com/oap-project/sql-ds-cache/pull/191)|[SQL-DS-CACHE-190]put plasma detector in seperate object to avoid unnecessary dependency of arrow| +|[#189](https://github.com/oap-project/sql-ds-cache/pull/189)|[SQL-DS-CACHE-188][POAE7-1253] improvement of fallback from plasma cache to simple cache| +|[#157](https://github.com/oap-project/sql-ds-cache/pull/157)|[SQL-DS-CACHE-155][POAE7-1187]reorg to support profile based multi spark version| + +### PMem Shuffle + +#### Bugs Fixed +||| +|:---|:---| +|[#46](https://github.com/oap-project/pmem-shuffle/issues/46)|Cannot run Terasort with pmem-shuffle of branch-1.2| +|[#43](https://github.com/oap-project/pmem-shuffle/issues/43)|Rpmp cannot be compiled due to the lack of boost header file.| + +#### PRs +||| +|:---|:---| +|[#51](https://github.com/oap-project/pmem-shuffle/pull/51)|[PMEM-SHUFFLE-50] Remove description about download submodules manually since they can be downloaded automatically.| +|[#49](https://github.com/oap-project/pmem-shuffle/pull/49)|[PMEM-SHUFFLE-48] Fix the bug about mapstatus tracking and add more connections for metastore.| +|[#47](https://github.com/oap-project/pmem-shuffle/pull/47)|[PMEM-SHUFFLE-46] Fix the bug that off-heap memory is over used in shuffle reduce stage. | +|[#40](https://github.com/oap-project/pmem-shuffle/pull/40)|[PMEM-SHUFFLE-39] Fix the bug that pmem-shuffle without RPMP fails to pass Terasort benchmark due to latest patch.| +|[#38](https://github.com/oap-project/pmem-shuffle/pull/38)|[PMEM-SHUFFLE-37] Add start-rpmp.sh and stop-rpmp.sh| +|[#33](https://github.com/oap-project/pmem-shuffle/pull/33)|[PMEM-SHUFFLE-28]Add RPMP with HA support and integrate it with Spark3.1.1| +|[#27](https://github.com/oap-project/pmem-shuffle/pull/27)|[PMEM-SHUFFLE] Change artifact name to make it compatible with naming…| + +### Remote Shuffle + +#### Bugs Fixed +||| +|:---|:---| +|[#24](https://github.com/oap-project/remote-shuffle/issues/24)|Enhance executor memory release| + +#### PRs +||| +|:---|:---| +|[#25](https://github.com/oap-project/remote-shuffle/pull/25)|[REMOTE-SHUFFLE-24] Enhance executor memory release| + + +## Release 1.1.1 + +### Native SQL Engine + +#### Features +||| +|:---|:---| +|[#304](https://github.com/oap-project/native-sql-engine/issues/304)|Upgrade to Arrow 4.0.0| +|[#285](https://github.com/oap-project/native-sql-engine/issues/285)|ColumnarWindow: Support Date/Timestamp input in MAX/MIN| +|[#297](https://github.com/oap-project/native-sql-engine/issues/297)|Disable incremental compiler in CI| +|[#245](https://github.com/oap-project/native-sql-engine/issues/245)|Support columnar rdd cache| +|[#276](https://github.com/oap-project/native-sql-engine/issues/276)|Add option to switch Hadoop version| +|[#274](https://github.com/oap-project/native-sql-engine/issues/274)|Comment to trigger tpc-h RAM test| +|[#256](https://github.com/oap-project/native-sql-engine/issues/256)|CI: do not run ram report for each PR| + +#### Bugs Fixed +||| +|:---|:---| 
+|[#325](https://github.com/oap-project/native-sql-engine/issues/325)|java.util.ConcurrentModificationException: mutation occurred during iteration| +|[#329](https://github.com/oap-project/native-sql-engine/issues/329)|numPartitions are not the same| +|[#318](https://github.com/oap-project/native-sql-engine/issues/318)|fix Spark 311 on data source v2| +|[#311](https://github.com/oap-project/native-sql-engine/issues/311)|Build reports errors| +|[#302](https://github.com/oap-project/native-sql-engine/issues/302)|test on v2 failed due to an exception| +|[#257](https://github.com/oap-project/native-sql-engine/issues/257)|different version of slf4j-log4j| +|[#293](https://github.com/oap-project/native-sql-engine/issues/293)|Fix BHJ loss if key = 0| +|[#248](https://github.com/oap-project/native-sql-engine/issues/248)|arrow dependency must put after arrow installation| + +#### PRs +||| +|:---|:---| +|[#332](https://github.com/oap-project/native-sql-engine/pull/332)|[NSE-325] fix incremental compile issue with 4.5.x scala-maven-plugin| +|[#335](https://github.com/oap-project/native-sql-engine/pull/335)|[NSE-329] fix out partitioning in BHJ and SHJ| +|[#328](https://github.com/oap-project/native-sql-engine/pull/328)|[NSE-318]check schema before reuse exchange| +|[#307](https://github.com/oap-project/native-sql-engine/pull/307)|[NSE-304] Upgrade to Arrow 4.0.0| +|[#312](https://github.com/oap-project/native-sql-engine/pull/312)|[NSE-311] Build reports errors| +|[#272](https://github.com/oap-project/native-sql-engine/pull/272)|[NSE-273] support spark311| +|[#303](https://github.com/oap-project/native-sql-engine/pull/303)|[NSE-302] fix v2 test| +|[#306](https://github.com/oap-project/native-sql-engine/pull/306)|[NSE-304] Upgrade to Arrow 4.0.0: Change basic GHA TPC-H test target …| +|[#286](https://github.com/oap-project/native-sql-engine/pull/286)|[NSE-285] ColumnarWindow: Support Date input in MAX/MIN| +|[#298](https://github.com/oap-project/native-sql-engine/pull/298)|[NSE-297] Disable incremental compiler in GHA CI| +|[#291](https://github.com/oap-project/native-sql-engine/pull/291)|[NSE-257] fix multiple slf4j bindings| +|[#294](https://github.com/oap-project/native-sql-engine/pull/294)|[NSE-293] fix unsafemap with key = '0'| +|[#233](https://github.com/oap-project/native-sql-engine/pull/233)|[NSE-207] fix issues found from aggregate unit tests| +|[#246](https://github.com/oap-project/native-sql-engine/pull/246)|[NSE-245]Adding columnar RDD cache support| +|[#289](https://github.com/oap-project/native-sql-engine/pull/289)|[NSE-206]Update installation guide and configuration guide.| +|[#277](https://github.com/oap-project/native-sql-engine/pull/277)|[NSE-276] Add option to switch Hadoop version| +|[#275](https://github.com/oap-project/native-sql-engine/pull/275)|[NSE-274] Comment to trigger tpc-h RAM test| +|[#271](https://github.com/oap-project/native-sql-engine/pull/271)|[NSE-196] clean up configs in unit tests| +|[#258](https://github.com/oap-project/native-sql-engine/pull/258)|[NSE-257] fix different versions of slf4j-log4j12| +|[#259](https://github.com/oap-project/native-sql-engine/pull/259)|[NSE-248] fix arrow dependency order| +|[#249](https://github.com/oap-project/native-sql-engine/pull/249)|[NSE-241] fix hashagg result length| +|[#255](https://github.com/oap-project/native-sql-engine/pull/255)|[NSE-256] do not run ram report test on each PR| + + +### SQL DS Cache + +#### Features +||| +|:---|:---| +|[#118](https://github.com/oap-project/sql-ds-cache/issues/118)|port to Spark 3.1.1| + 
+#### Bugs Fixed +||| +|:---|:---| +|[#121](https://github.com/oap-project/sql-ds-cache/issues/121)|OAP Index creation stuck issue| + +#### PRs +||| +|:---|:---| +|[#132](https://github.com/oap-project/sql-ds-cache/pull/132)|Fix SampleBasedStatisticsSuite UnitTest case| +|[#122](https://github.com/oap-project/sql-ds-cache/pull/122)|[ sql-ds-cache-121] Fix Index stuck issues| +|[#119](https://github.com/oap-project/sql-ds-cache/pull/119)|[SQL-DS-CACHE-118][POAE7-1130] port sql-ds-cache to Spark3.1.1| + + +### OAP MLlib + +#### Features +||| +|:---|:---| +|[#26](https://github.com/oap-project/oap-mllib/issues/26)|[PIP] Support Spark 3.0.1 / 3.0.2 and upcoming 3.1.1| + +#### PRs +||| +|:---|:---| +|[#39](https://github.com/oap-project/oap-mllib/pull/39)|[ML-26] Build for different spark version by -Pprofile| + + +### PMem Spill + +#### Features +||| +|:---|:---| +|[#34](https://github.com/oap-project/pmem-spill/issues/34)|Support vanilla spark 3.1.1| + +#### PRs +||| +|:---|:---| +|[#41](https://github.com/oap-project/pmem-spill/pull/41)|[PMEM-SPILL-34][POAE7-1119]Port RDD cache to Spark 3.1.1 as separate module| + + +### PMem Common + +#### Features +||| +|:---|:---| +|[#10](https://github.com/oap-project/pmem-common/issues/10)|add -mclflushopt flag to enable clflushopt for gcc| +|[#8](https://github.com/oap-project/pmem-common/issues/8)|use clflushopt instead of clflush | + +#### PRs +||| +|:---|:---| +|[#11](https://github.com/oap-project/pmem-common/pull/11)|[PMEM-COMMON-10][POAE7-1010]Add -mclflushopt flag to enable clflushop…| +|[#9](https://github.com/oap-project/pmem-common/pull/9)|[PMEM-COMMON-8][POAE7-896]use clflush optimize version for clflush| + + +### PMem Shuffle + +#### Features +||| +|:---|:---| +|[#15](https://github.com/oap-project/pmem-shuffle/issues/15)|Doesn't work with Spark3.1.1| + +#### PRs +||| +|:---|:---| +|[#16](https://github.com/oap-project/pmem-shuffle/pull/16)|[pmem-shuffle-15] Make pmem-shuffle support Spark3.1.1| + + +### Remote Shuffle + +#### Features +||| +|:---|:---| +|[#18](https://github.com/oap-project/remote-shuffle/issues/18)|upgrade to Spark-3.1.1| +|[#11](https://github.com/oap-project/remote-shuffle/issues/11)|Support DAOS Object Async API| + +#### PRs +||| +|:---|:---| +|[#19](https://github.com/oap-project/remote-shuffle/pull/19)|[REMOTE-SHUFFLE-18] upgrade to Spark-3.1.1| +|[#14](https://github.com/oap-project/remote-shuffle/pull/14)|[REMOTE-SHUFFLE-11] Support DAOS Object Async API| + + ## Release 1.1.0 -* [Native SQL Engine](#native-sql-engine) -* [SQL DS Cache](#sql-ds-cache) -* [OAP MLlib](#oap-mllib) -* [PMEM Spill](#pmem-spill) -* [PMEM Shuffle](#pmem-shuffle) -* [Remote Shuffle](#remote-shuffle) ### Native SQL Engine @@ -225,7 +555,7 @@ Generated on 2021-04-29 |[#19](https://github.com/oap-project/oap-mllib/pull/19)|[ML-18] Auto detect KVS port for oneCCL to avoid port conflict| -### PMEM Spill +### PMem Spill #### Bugs Fixed ||| @@ -245,7 +575,7 @@ Generated on 2021-04-29 |[#10](https://github.com/oap-project/pmem-spill/pull/10)|Fixing one pmem path on AppDirect mode may cause the pmem initialization path to be empty Path| -### PMEM Shuffle +### PMem Shuffle #### Features ||| @@ -264,7 +594,7 @@ Generated on 2021-04-29 |[#6](https://github.com/oap-project/pmem-shuffle/pull/6)|[PMEM-SHUFFLE-7] enable fsdax mode in pmem-shuffle| -### Remote-Shuffle +### Remote Shuffle #### Features ||| diff --git a/README.md b/README.md index 10338c18..7b78b501 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ following configurations in 
spark-defaults.conf or Spark submit command line arg Note: For DAOS users, DAOS Hadoop/Java API jars should also be included in the classpath as we leverage DAOS Hadoop filesystem. ``` - spark.executor.extraClassPath $HOME/miniconda2/envs/oapenv/oap_jars/remote-shuffle-.jar - spark.driver.extraClassPath $HOME/miniconda2/envs/oapenv/oap_jars/remote-shuffle-.jar + spark.executor.extraClassPath $HOME/miniconda2/envs/oapenv/oap_jars/shuffle-hadoop-.jar + spark.driver.extraClassPath $HOME/miniconda2/envs/oapenv/oap_jars/shuffle-hadoop-.jar ``` Enable the remote shuffle manager and specify the Hadoop storage system URI holding shuffle data. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..eb482d90 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +# Security Policy + +## Report a Vulnerability + +Please report security issues or vulnerabilities to the [Intel® Security Center]. + +For more information on how Intel® works to resolve security issues, see +[Vulnerability Handling Guidelines]. + +[Intel® Security Center]:https://www.intel.com/security + +[Vulnerability Handling Guidelines]:https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html diff --git a/docs/OAP-Developer-Guide.md b/docs/OAP-Developer-Guide.md index bf7b036b..30d9ed97 100644 --- a/docs/OAP-Developer-Guide.md +++ b/docs/OAP-Developer-Guide.md @@ -3,13 +3,13 @@ This document contains the instructions & scripts on installing necessary dependencies and building OAP modules. You can get more detailed information from OAP each module below. -* [SQL Index and Data Source Cache](https://github.com/oap-project/sql-ds-cache/blob/v1.1.0-spark-3.0.0/docs/Developer-Guide.md) -* [PMem Common](https://github.com/oap-project/pmem-common/tree/v1.1.0-spark-3.0.0) -* [PMem Spill](https://github.com/oap-project/pmem-spill/tree/v1.1.0-spark-3.0.0) -* [PMem Shuffle](https://github.com/oap-project/pmem-shuffle/tree/v1.1.0-spark-3.0.0#5-install-dependencies-for-pmem-shuffle) -* [Remote Shuffle](https://github.com/oap-project/remote-shuffle/tree/v1.1.0-spark-3.0.0) -* [OAP MLlib](https://github.com/oap-project/oap-mllib/tree/v1.1.0-spark-3.0.0) -* [Native SQL Engine](https://github.com/oap-project/native-sql-engine/tree/v1.1.0-spark-3.0.0) +* [SQL Index and Data Source Cache](https://github.com/oap-project/sql-ds-cache/blob/v1.2.0/docs/Developer-Guide.md) +* [PMem Common](https://github.com/oap-project/pmem-common/tree/v1.2.0) +* [PMem Spill](https://github.com/oap-project/pmem-spill/tree/v1.2.0) +* [PMem Shuffle](https://github.com/oap-project/pmem-shuffle/tree/v1.2.0#5-install-dependencies-for-pmem-shuffle) +* [Remote Shuffle](https://github.com/oap-project/remote-shuffle/tree/v1.2.0) +* [OAP MLlib](https://github.com/oap-project/oap-mllib/tree/v1.2.0) +* [Gazelle Plugin](https://github.com/oap-project/gazelle_plugin/tree/v1.2.0) ## Building OAP @@ -22,45 +22,42 @@ We provide scripts to help automatically install dependencies required, please c # cd oap-tools # sh dev/install-compile-time-dependencies.sh ``` -*Note*: oap-tools tag version `v1.1.0-spark-3.0.0` corresponds to all OAP modules' tag version `v1.1.0-spark-3.0.0`. +*Note*: oap-tools tag version `v1.2.0` corresponds to all OAP modules' tag version `v1.2.0`. 
Then the dependencies below will be installed: -* [Cmake](https://help.directadmin.com/item.php?id=494) +* [Cmake](https://cmake.org/install/) * [GCC > 7](https://gcc.gnu.org/wiki/InstallingGCC) * [Memkind](https://github.com/memkind/memkind/tree/v1.10.1) * [Vmemcache](https://github.com/pmem/vmemcache) * [HPNL](https://github.com/Intel-bigdata/HPNL) * [PMDK](https://github.com/pmem/pmdk) * [OneAPI](https://software.intel.com/content/www/us/en/develop/tools/oneapi.html) -* [Arrow](https://github.com/oap-project/arrow/tree/arrow-3.0.0-oap-1.1) +* [Arrow](https://github.com/oap-project/arrow/tree/v4.0.0-oap-1.2.0) * [LLVM](https://llvm.org/) -Run the following command to learn more. - -``` -# sh dev/scripts/prepare_oap_env.sh --help -``` - -Run the following command to automatically install specific dependency such as Maven. - -``` -# sh dev/scripts/prepare_oap_env.sh --prepare_maven -``` - - **Requirements for Shuffle Remote PMem Extension** If enable Shuffle Remote PMem extension with RDMA, you can refer to [PMem Shuffle](https://github.com/oap-project/pmem-shuffle) to configure and validate RDMA in advance. ### Building +#### Building OAP + OAP is built with [Apache Maven](http://maven.apache.org/) and Oracle Java 8. -To build OAP package, run command below then you can find a tarball named `oap-$VERSION-bin-spark-$VERSION.tar.gz` under directory `$OAP_TOOLS_HOME/dev/release-package `. +To build OAP package, run command below then you can find a tarball named `oap-$VERSION-*.tar.gz` under directory `$OAP_TOOLS_HOME/dev/release-package `, which contains all OAP module jars. +Change to `root` user, run + ``` -$ sh $OAP_TOOLS_HOME/dev/compile-oap.sh +# cd oap-tools +# sh dev/compile-oap.sh ``` -Building specified OAP Module, such as `sql-ds-cache`, run: +#### Building OAP specific module + +If you just want to build a specific OAP Module, such as `sql-ds-cache`, change to `root` user, then run: + ``` -$ sh $OAP_TOOLS_HOME/dev/compile-oap.sh --sql-ds-cache +# cd oap-tools +# sh dev/compile-oap.sh --component=sql-ds-cache ``` diff --git a/docs/OAP-Installation-Guide.md b/docs/OAP-Installation-Guide.md index c269b978..a4d6b16f 100644 --- a/docs/OAP-Installation-Guide.md +++ b/docs/OAP-Installation-Guide.md @@ -26,20 +26,19 @@ To test your installation, run the command `conda list` in your terminal window ### Installing OAP Create a Conda environment and install OAP Conda package. + ```bash -$ conda create -n oapenv -y python=3.7 -$ conda activate oapenv -$ conda install -c conda-forge -c intel -y oap=1.1.0 +$ conda create -n oapenv -c conda-forge -c intel -y oap=1.2.0 ``` Once finished steps above, you have completed OAP dependencies installation and OAP building, and will find built OAP jars under `$HOME/miniconda2/envs/oapenv/oap_jars` Dependencies below are required by OAP and all of them are included in OAP Conda package, they will be automatically installed in your cluster when you Conda install OAP. Ensure you have activated environment which you created in the previous steps. 
-- [Arrow](https://github.com/Intel-bigdata/arrow) +- [Arrow](https://github.com/oap-project/arrow/tree/v4.0.0-oap-1.2.0) - [Plasma](http://arrow.apache.org/blog/2017/08/08/plasma-in-memory-object-store/) -- [Memkind](https://anaconda.org/intel/memkind) -- [Vmemcache](https://anaconda.org/intel/vmemcache) +- [Memkind](https://github.com/memkind/memkind/tree/v1.10.1) +- [Vmemcache](https://github.com/pmem/vmemcache.git) - [HPNL](https://anaconda.org/intel/hpnl) - [PMDK](https://github.com/pmem/pmdk) - [OneAPI](https://software.intel.com/content/www/us/en/develop/tools/oneapi.html) diff --git a/pom.xml b/pom.xml index cb6a8f84..c79b010d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,17 +6,17 @@ com.intel.oap remote-shuffle-parent - 1.2.0 + 2.2.0-SNAPSHOT OAP Remote Shuffle Parent POM pom - 2.12.10 - 2.12 + 2.12.10 + 2.12 1.8 ${java.version} ${java.version} - 3.0.0 + 3.3.3 @@ -214,8 +214,9 @@ org.apache.spark - spark-core_2.12 + spark-core_${scala.binary.version} ${spark.version} + provided junit @@ -225,13 +226,13 @@ org.scalatest - scalatest_2.12 - 3.0.8 + scalatest_${scala.binary.version} + 3.2.3 test org.apache.spark - spark-core_2.12 + spark-core_${scala.binary.version} ${spark.version} tests test-jar diff --git a/shuffle-daos/README.md b/shuffle-daos/README.md index e1a30fb3..6ab42ab9 100644 --- a/shuffle-daos/README.md +++ b/shuffle-daos/README.md @@ -1,6 +1,6 @@ # Remote Shuffle Based on DAOS Object API A remote shuffle plugin based on DAOS Object API. You can find DAOS and DAOS Java Wrapper in https://github.com/daos-stack/daos and https://github.com/daos-stack/daos/tree/master/src/client/java. -Thanks to DAOS, the plugin is espacially good for small shuffle block, such as around 200KB. +Thanks to DAOS, the plugin is especially good for small shuffle block, such as around 200KB. See Shuffle DAOS related documentation in [Readme under project root](../README.md). 
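For readers following the classpath and shuffle-manager configuration patched in the root README above, here is a minimal programmatic sketch of driving a pluggable shuffle manager from a Spark job. It is an illustration only: the manager class name `org.apache.spark.shuffle.daos.DaosShuffleManager` is an assumption (this diff only shows the `org.apache.spark.shuffle.daos` package, not the manager class itself), and any storage-URI settings mentioned in the README would still need to be supplied.

```java
// Hypothetical usage sketch; the shuffle manager class name below is assumed,
// not taken from this diff. Check the shipped shuffle jar for the real name.
import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class RemoteShuffleDemo {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .setAppName("remote-shuffle-demo")
        // assumed class name for the DAOS-backed shuffle manager in this module
        .set("spark.shuffle.manager", "org.apache.spark.shuffle.daos.DaosShuffleManager");
    try (JavaSparkContext sc = new JavaSparkContext(conf)) {
      JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
      // reduceByKey forces a shuffle stage, which the configured manager serves
      long groups = nums.mapToPair(n -> new Tuple2<>(n % 3, 1))
          .reduceByKey(Integer::sum)
          .count();
      System.out.println("distinct keys after shuffle: " + groups);
    }
  }
}
```

Submitted with `spark-submit` and the shuffle jar on the driver and executor classpaths, as shown in the README snippet above, the `reduceByKey` stage exercises the remote shuffle path.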
diff --git a/shuffle-daos/pom.xml b/shuffle-daos/pom.xml index 0626465e..afcd8a04 100644 --- a/shuffle-daos/pom.xml +++ b/shuffle-daos/pom.xml @@ -7,13 +7,14 @@ com.intel.oap remote-shuffle-parent - 1.2.0 + 2.2.0-SNAPSHOT - shuffle-daos + remote-shuffle-daos OAP Remote Shuffle Based on DAOS Object API jar + ${project.artifactId}-${project.version}-with-spark-${spark.version} org.codehaus.mojo @@ -114,6 +115,42 @@ ${java.version} + + maven-shade-plugin + + + shade-netty4 + package + + shade + + + true + + + META-INF/org.apache.hadoop.fs.FileSystem + + + + + io.netty.buffer + io.netty.buffershade4 + + + io.netty.util + io.netty.utilshade4 + + + + + + + + ${artifactId}-${version}-with-spark-${spark.version}-netty4-txf + + + + org.apache.maven.plugins maven-surefire-plugin @@ -200,12 +237,12 @@ org.apache.spark - spark-core_2.12 + spark-core_${scala.binary.version} io.daos daos-java - 1.2.2 + 2.4.1 junit @@ -214,12 +251,12 @@ org.scalatest - scalatest_2.12 + scalatest_${scala.binary.version} test org.apache.spark - spark-core_2.12 + spark-core_${scala.binary.version} tests test-jar test @@ -241,4 +278,10 @@ + + + maven-snapshots + http://oss.sonatype.org/content/repositories/snapshots + + diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/BoundThreadExecutors.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/BoundThreadExecutors.java index 083a1e9c..e0b8a6c1 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/BoundThreadExecutors.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/BoundThreadExecutors.java @@ -26,7 +26,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.concurrent.*; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; /** diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosParallelReaderAsync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosParallelReaderAsync.java new file mode 100644 index 00000000..d7aecae6 --- /dev/null +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosParallelReaderAsync.java @@ -0,0 +1,396 @@ +/* + * (C) Copyright 2018-2021 Intel Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * GOVERNMENT LICENSE RIGHTS-OPEN SOURCE SOFTWARE + * The Government's rights to use, modify, reproduce, release, perform, display, + * or disclose this software are subject to the terms of the Apache License as + * provided in Contract No. B609815. + * Any reproduction of computer software, computer software documentation, or + * portions thereof marked with this legend must also reproduce the markings. 
+ */ + +package org.apache.spark.shuffle.daos; + +import io.daos.DaosEventQueue; +import io.daos.obj.DaosObject; +import io.daos.obj.IODataDesc; +import io.daos.obj.IOSimpleDDAsync; +import io.netty.buffer.ByteBuf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Reader for reading content from dkey/akeys without knowing their lengths. + * + * determine end of content by checking returned actual size. + */ +public class DaosParallelReaderAsync extends DaosReaderAsync { + + private LinkedList mapFetchList = new LinkedList<>(); + + private Set descSet = new HashSet<>(); + + private FetchState currentState; + + private long totalInMemSize; + + private static final Logger log = LoggerFactory.getLogger(DaosParallelReaderAsync.class); + + public DaosParallelReaderAsync(DaosObject object, ReaderConfig config) throws IOException { + super(object, config); + } + + @Override + public ByteBuf nextBuf() throws IOException { + nextMap = false; + ByteBuf buf = tryCurrentState(); + if (buf != null) { + return buf; + } + // next entry + buf = tryNextState(); + if (buf != null) { + return buf; + } + readFromDaos(); + return tryNextState(); + } + + private ByteBuf tryNextState() throws IOException { + if (currentState != null) { + nextMap = true; + if (!currentState.dataEntries.isEmpty()) { + throw new IllegalStateException("dataEntries should be empty in current state"); + } + FetchState first = mapFetchList.pollFirst(); + if (first != currentState) { + throw new IllegalStateException("currentState should be the first entry of mapFetchList"); + } + } + currentState = mapFetchList.peekFirst(); + if (currentState != null) { + return tryCurrentState(); + } + return null; + } + + private ByteBuf tryCurrentState() throws IOException { + if (currentState != null) { + return currentState.getBuffer(); + } + return null; + } + + @Override + protected ByteBuf readFromDaos() throws IOException { + try { + return super.readFromDaos(); + } catch (IOException e) { + releaseDescSet(); + throw e; + } + } + + private void releaseDescSet() { + descSet.forEach(desc -> desc.discard()); + } + + @Override + protected Class getIODescClass() { + return IODescWithState.class; + } + + @Override + protected IODescWithState createFetchDataDesc(String reduceId) throws IOException { + return new IODescWithState(reduceId, false, eq.getEqWrapperHdl()); + } + + @Override + protected IOSimpleDDAsync createNextDesc(long sizeLimit) throws IOException { + long remaining = sizeLimit; + IODescWithState desc = null; + // fetch more for existing states + for (FetchState state : mapFetchList) { + if (remaining == 0) { + break; + } + if (desc == null) { + desc = createFetchDataDesc(String.valueOf(state.mapReduceId._2)); + } + long readSize = state.prepareFetch(desc, remaining); + remaining -= readSize; + totalInMemSize += readSize; + if (totalInMemSize > config.getMaxMem()) { + remaining = 0; + break; + } + } + // fetch more + int reduceId = -1; + while (remaining > 0) { + curMapReduceId = null; // forward mapreduce id + nextMapReduceId(); + if (curMapReduceId == null) { + break; + } + if (reduceId > 0 & (curMapReduceId._2 != reduceId)) { // make sure entries under same reduce + throw new IllegalStateException("multiple reduce ids"); + } + reduceId = curMapReduceId._2; + FetchState state = new 
FetchState(curMapReduceId); + mapFetchList.add(state); + if (desc == null) { + desc = createFetchDataDesc(String.valueOf(reduceId)); + } + long readSize = state.prepareFetch(desc, remaining); + remaining -= readSize; + totalInMemSize += readSize; + if (totalInMemSize > config.getMaxMem()) { + break; + } + } + if (desc != null) { + if (desc.getNbrOfEntries() == 0) { + desc.release(); + return null; + } + descSet.add(desc); + } + return desc; + } + + @Override + public void checkPartitionSize() { + if (lastMapReduceIdForReturn == null) { + return; + } + metrics.incRemoteBlocksFetched(1); + } + + @Override + protected ByteBuf enterNewDesc(IOSimpleDDAsync desc) throws IOException { + if (log.isDebugEnabled()) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < desc.getAkeyEntries().size(); i++) { + sb.append(desc.getEntry(i).getFetchedData().readableBytes()).append(","); + } + log.debug("desc: " + desc + "\n returned lengths: " + sb); + } + List list = desc.getAkeyEntries(); + if (list != null && !list.isEmpty()) { // entries could be removed in verifyCompleted + return list.get(0).getDataBuffer(); + } + return null; + } + + @Override + protected void verifyCompleted() throws IOException { + IODescWithState failed = null; + int failedCnt = 0; + for (DaosEventQueue.Attachment attachment : completedList) { + IODescWithState desc = (IODescWithState) attachment; + runningDescSet.remove(attachment); + if (desc.isSucceeded()) { + readyList.add(desc); + desc.updateFetchState(); + continue; + } + failedCnt++; + if (failed == null) { + failed = desc; + } else { + desc.release(); + } + } + if (failedCnt > 0) { + IOException e = new IOException("failed to read " + failedCnt + " IOSimpleDDAsync. Return code is " + + failed.getReturnCode() + ". First failed is " + failed); + releaseDescSet(); + throw e; + } + } + + @Override + public void close(boolean force) { + try { + super.close(force); + if (!(mapFetchList.isEmpty() && descSet.isEmpty())) { + throw new IllegalStateException("not all data consumed"); + } + } finally { + releaseDescSet(); + mapFetchList.clear(); + } + } + + private class FetchState { + private long offset; + private int times; + private Tuple2 mapReduceId; + private long size; + private LinkedList> dataEntries = new LinkedList<>(); + + private FetchState(Tuple2 mapReduceId) { + this.mapReduceId = mapReduceId; + this.size = partSizeMap.get(mapReduceId)._1 + 100; // +100 make less call + } + + private long prepareFetch(IODescWithState desc, long remaining) throws IOException { + if (size == 0L) { + return 0L; + } + times++; + long readSize = times * size; + if (readSize > remaining) { + readSize = remaining; + } + addFetchEntry(desc, mapReduceId._1, offset, readSize); // update offset after fetching + IOSimpleDDAsync.AsyncEntry entry = desc.getEntry(desc.getNbrOfEntries() - 1); + dataEntries.add(new Tuple2<>(desc, entry)); + desc.putState(entry, this); + return readSize; + } + + private boolean updateState(IOSimpleDDAsync.AsyncEntry entry) { + Tuple2 lastTuple = dataEntries.getLast(); + if (entry != lastTuple._2) { + throw new IllegalStateException("entries mismatch"); + } + int actualSize = entry.getActualSize(); + offset += actualSize; + int requestSize = entry.getRequestSize(); + if (requestSize > actualSize) { + size = 0L; // indicate end of akey content + } + if (actualSize == 0) { + totalInMemSize -= entry.getDataBuffer().capacity(); + entry.releaseDataBuffer(); // release mem ASAP + dataEntries.remove(lastTuple); + return true; + } + return false; + } + + private 
ByteBuf tryCurrentEntry() { + if (currentEntry != null && (!currentEntry.isFetchBufReleased())) { + ByteBuf buf = currentEntry.getFetchedData(); + if (buf.readableBytes() > 0) { + return buf; + } + // release buffer as soon as possible + currentEntry.releaseDataBuffer(); + totalInMemSize -= buf.capacity(); + } + return null; + } + + private ByteBuf getBuffer() throws IOException { + while (true) { + ByteBuf buf = readMore(); + if (buf != null) { + return buf; + } + if (reachEnd()) { + metrics.incRemoteBlocksFetched(1); + break; + } + if (readFromDaos() == null) { + break; + } + } + return null; + } + + private ByteBuf readMore() { + while (true) { + ByteBuf buf = tryCurrentEntry(); + if (buf != null) { + return buf; + } + // remove and release entry + if (currentEntry != null) { + Tuple2 tuple = dataEntries.removeFirst(); + tuple._1.removeState(tuple._2); + currentEntry = null; + } + // get next tuple + Tuple2 nextTuple = dataEntries.peekFirst(); + if (nextTuple != null) { + currentEntry = nextTuple._2; + // update metrics + metrics.incRemoteBytesRead(currentEntry.getFetchedData().readableBytes()); + } else { + break; + } + } + return null; + } + + private boolean reachEnd() { + return size == 0L; + } + } + + private class IODescWithState extends IOSimpleDDAsync { + private Map entryStateMap = new HashMap<>(); + + private IODescWithState(String dkey, boolean updateOrFetch, long eqWrapperHandle) throws IOException { + super(dkey, updateOrFetch, eqWrapperHandle); + } + + private void putState(IOSimpleDDAsync.AsyncEntry entry, FetchState state) { + entryStateMap.put(entry, state); + } + + private void removeState(IOSimpleDDAsync.AsyncEntry entry) { + entry.releaseDataBuffer(); + if (entryStateMap.remove(entry) == null) { + throw new IllegalStateException("failed to remove state from Desc"); + } + tryReleaseState(); + } + + private void updateFetchState() { + Iterator> it = entryStateMap.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + if (entry.getValue().updateState(entry.getKey())) { + entry.getKey().releaseDataBuffer(); + it.remove(); + } + } + tryReleaseState(); + } + + private void tryReleaseState() { + if (entryStateMap.isEmpty()) { + release(); + if (!descSet.remove(this)) { + throw new IllegalStateException("failed to remove desc from descset"); + } + } + } + } +} diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java index a1a16273..02e93d0a 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java @@ -29,11 +29,9 @@ import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; -import scala.Tuple3; import java.io.IOException; import java.util.LinkedHashMap; @@ -74,7 +72,7 @@ public interface DaosReader { * @param metrics * @return */ - void prepare(LinkedHashMap, Tuple3> partSizeMap, + void prepare(LinkedHashMap, Tuple2> partSizeMap, long maxBytesInFlight, long maxReqSizeShuffleToMem, ShuffleReadMetricsReporter metrics); /** @@ -82,7 +80,7 @@ void prepare(LinkedHashMap, Tuple3 curMapReduceId(); + Tuple2 curMapReduceId(); /** * find next mapReduce Id @@ -131,7 +129,6 @@ final class ReaderConfig { private long maxMem; private int readBatchSize; 
private int waitDataTimeMs; - private int waitTimeoutTimes; private boolean fromOtherThread; private SparkConf conf; @@ -157,7 +154,6 @@ private void initialize() { this.maxMem = -1L; this.readBatchSize = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_BATCH_SIZE()); this.waitDataTimeMs = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_WAIT_MS()); - this.waitTimeoutTimes = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_WAIT_DATA_TIMEOUT_TIMES()); this.fromOtherThread = (boolean)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_FROM_OTHER_THREAD()); if (log.isDebugEnabled()) { log.debug("minReadSize: " + minReadSize); @@ -165,7 +161,6 @@ private void initialize() { log.debug("maxMem: " + maxMem); log.debug("readBatchSize: " + readBatchSize); log.debug("waitDataTimeMs: " + waitDataTimeMs); - log.debug("waitTimeoutTimes: " + waitTimeoutTimes); log.debug("fromOtherThread: " + fromOtherThread); } } @@ -176,7 +171,6 @@ public ReaderConfig copy(long maxBytesInFlight, long maxMem) { rc.minReadSize = minReadSize; rc.readBatchSize = readBatchSize; rc.waitDataTimeMs = waitDataTimeMs; - rc.waitTimeoutTimes = waitTimeoutTimes; rc.fromOtherThread = fromOtherThread; if (maxBytesInFlight < rc.minReadSize) { rc.maxBytesInFlight = minReadSize; @@ -194,10 +188,6 @@ public int getWaitDataTimeMs() { return waitDataTimeMs; } - public int getWaitTimeoutTimes() { - return waitTimeoutTimes; - } - public long getMaxBytesInFlight() { return maxBytesInFlight; } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderAsync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderAsync.java index df0ead60..6403df69 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderAsync.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderAsync.java @@ -29,19 +29,26 @@ import io.daos.obj.IODataDescBase; import io.daos.obj.IOSimpleDDAsync; import io.netty.buffer.ByteBuf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.*; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; public class DaosReaderAsync extends DaosReaderBase { - private DaosEventQueue eq; + protected DaosEventQueue eq; - private Set runningDescSet = new LinkedHashSet<>(); + protected Set runningDescSet = new LinkedHashSet<>(); - private LinkedList readyList = new LinkedList<>(); + protected LinkedList readyList = new LinkedList<>(); - private List completedList = new LinkedList<>(); + protected List completedList = new LinkedList<>(); + + private static final Logger log = LoggerFactory.getLogger(DaosReaderAsync.class); public DaosReaderAsync(DaosObject object, ReaderConfig config) throws IOException { super(object, config); @@ -49,7 +56,7 @@ public DaosReaderAsync(DaosObject object, ReaderConfig config) throws IOExceptio } @Override - protected IODataDescBase createFetchDataDesc(String reduceId) throws IOException { + protected IOSimpleDDAsync createFetchDataDesc(String reduceId) throws IOException { return object.createAsyncDataDescForFetch(reduceId, eq.getEqWrapperHdl()); } @@ -77,11 +84,20 @@ public ByteBuf nextBuf() throws IOException { return readFromDaos(); } - private ByteBuf enterNewDesc(IOSimpleDDAsync desc) throws IOException { + protected ByteBuf enterNewDesc(IOSimpleDDAsync desc) throws IOException { if (currentDesc != null) { currentDesc.release(); } currentDesc = desc; + + if (log.isDebugEnabled()) { + StringBuilder sb = new StringBuilder(); + for (int 
i = 0; i < desc.getAkeyEntries().size(); i++) { + sb.append(desc.getEntry(i).getFetchedData().readableBytes()).append(","); + } + log.debug("desc: " + desc + "\n returned lengths: " + sb); + } + return validateLastEntryAndGetBuf(desc.getEntry(entryIdx)); } @@ -96,45 +112,46 @@ private void progress() throws IOException { completedList.clear(); long timeOutMs = config.getWaitDataTimeMs(); long start = System.currentTimeMillis(); - int n = eq.pollCompleted(completedList, runningDescSet.size(), timeOutMs); - while (n == 0) { - long dur = System.currentTimeMillis() - start; - if (dur > timeOutMs) { - throw new TimedOutException("timed out after " + dur); + long dur = 0L; + do { + eq.pollCompleted(completedList, getIODescClass(), runningDescSet, runningDescSet.size(), + timeOutMs - dur); + if (completedList.isEmpty()) { + dur = System.currentTimeMillis() - start; + if (dur > timeOutMs) { + throw new TimedOutException("timed out after " + dur); + } } - n = eq.pollCompleted(completedList, runningDescSet.size(), timeOutMs - dur); - } + } while (completedList.isEmpty()); verifyCompleted(); } - private ByteBuf readFromDaos() throws IOException { + protected Class getIODescClass() { + return IOSimpleDDAsync.class; + } + + protected ByteBuf readFromDaos() throws IOException { if (runningDescSet.isEmpty()) { - DaosEventQueue.Event event = null; - TimedOutException te = null; - try { - event = acquireEvent(); - } catch (TimedOutException e) { - te = e; - } + DaosEventQueue.Event event = acquireEvent(); IOSimpleDDAsync taskDesc = (IOSimpleDDAsync) createNextDesc(config.getMaxBytesInFlight()); - if (taskDesc == null) { - if (event != null) { - event.putBack(); + if (taskDesc != null) { + assert Thread.currentThread().getId() == eq.getThreadId() : "current thread " + Thread.currentThread().getId() + + "(" + Thread.currentThread().getName() + "), is not expected " + eq.getThreadId() + "(" + + eq.getThreadName() + ")"; + + runningDescSet.add(taskDesc); + taskDesc.setEvent(event); + try { + object.fetchAsync(taskDesc); + } catch (IOException e) { + taskDesc.release(); + runningDescSet.remove(taskDesc); + throw e; } + } else { + eq.returnEvent(event); return null; } - if (te != null) { // have data to read, but no event - throw te; - } - runningDescSet.add(taskDesc); - taskDesc.setEvent(event); - try { - object.fetchAsync(taskDesc); - } catch (IOException e) { - taskDesc.release(); - runningDescSet.remove(taskDesc); - throw e; - } } progress(); IOSimpleDDAsync desc = nextDesc(); @@ -146,42 +163,54 @@ private ByteBuf readFromDaos() throws IOException { private DaosEventQueue.Event acquireEvent() throws IOException { completedList.clear(); - DaosEventQueue.Event event = eq.acquireEventBlocking(config.getWaitDataTimeMs(), completedList); + DaosEventQueue.Event event = eq.acquireEventBlocking(config.getWaitDataTimeMs(), completedList, + IOSimpleDDAsync.class, runningDescSet); verifyCompleted(); return event; } - private void verifyCompleted() throws IOException { + protected void verifyCompleted() throws IOException { IOSimpleDDAsync failed = null; int failedCnt = 0; for (DaosEventQueue.Attachment attachment : completedList) { - if (runningDescSet.contains(attachment)) { - IOSimpleDDAsync desc = (IOSimpleDDAsync) attachment; - runningDescSet.remove(attachment); - if (desc.isSucceeded()) { - readyList.add(desc); - continue; - } - failedCnt++; - if (failed == null) { - failed = desc; - } + IOSimpleDDAsync desc = (IOSimpleDDAsync) attachment; + runningDescSet.remove(attachment); + if (desc.isSucceeded()) { + 
readyList.add(desc); + continue; + } + failedCnt++; + if (failed == null) { + failed = desc; + } else { + desc.release(); } } if (failedCnt > 0) { - throw new IOException("failed to read " + failedCnt + " IOSimpleDDAsync. First failed is " + failed); + IOException e = new IOException("failed to read " + failedCnt + " IOSimpleDDAsync. Return code is " + + failed.getReturnCode() + ". First failed is " + failed); + failed.release(); + throw e; } } @Override public void close(boolean force) { - readyList.forEach(desc -> desc.release()); - runningDescSet.forEach(desc -> desc.release()); - if (!(readyList.isEmpty() && runningDescSet.isEmpty())) { + IllegalStateException e = null; + if (!(readyList.isEmpty() & runningDescSet.isEmpty())) { StringBuilder sb = new StringBuilder(); sb.append(readyList.isEmpty() ? "" : "not all data consumed. "); - sb.append(runningDescSet.isEmpty() ? "" : "some data is on flight"); - throw new IllegalStateException(sb.toString()); + sb.append(runningDescSet.isEmpty() ? "" : "some data is on flight "); + e = new IllegalStateException(sb.toString()); + } + readyList.forEach(desc -> desc.release()); + runningDescSet.forEach(desc -> desc.discard()); // to be released when poll + if (currentDesc != null) { + currentDesc.release(); + currentDesc = null; + } + if (e != null) { + throw e; } } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderBase.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderBase.java index 5232b205..90f88b84 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderBase.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderBase.java @@ -5,9 +5,7 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManagerId; import scala.Tuple2; -import scala.Tuple3; import java.io.IOException; import java.util.Iterator; @@ -22,18 +20,18 @@ public abstract class DaosReaderBase implements DaosReader { protected ReaderConfig config; - protected LinkedHashMap, Tuple3> partSizeMap; + protected LinkedHashMap, Tuple2> partSizeMap; - protected Iterator> mapIdIt; + protected Iterator> mapIdIt; protected ShuffleReadMetricsReporter metrics; protected long currentPartSize; - protected Tuple2 curMapReduceId; - protected Tuple2 lastMapReduceIdForSubmit; - protected Tuple2 lastMapReduceIdForReturn; - protected int curOffset; + protected Tuple2 curMapReduceId; + protected Tuple2 lastMapReduceIdForSubmit; + protected Tuple2 lastMapReduceIdForReturn; + protected long curOffset; protected boolean nextMap; protected IODataDescBase currentDesc; @@ -68,7 +66,7 @@ public void setReaderMap(Map readerMap) { } @Override - public void prepare(LinkedHashMap, Tuple3> partSizeMap, + public void prepare(LinkedHashMap, Tuple2> partSizeMap, long maxBytesInFlight, long maxReqSizeShuffleToMem, ShuffleReadMetricsReporter metrics) { this.partSizeMap = partSizeMap; this.config = config.copy(maxBytesInFlight, maxReqSizeShuffleToMem); @@ -78,7 +76,7 @@ public void prepare(LinkedHashMap, Tuple3 curMapReduceId() { + public Tuple2 curMapReduceId() { return lastMapReduceIdForSubmit; } @@ -108,9 +106,7 @@ protected ByteBuf tryCurrentDesc() throws IOException { entryIdx++; } entryIdx = 0; - // no need to release desc since all its entries are released in tryCurrentEntry and - // internal buffers are released after object.fetch - // reader.close will release all in case of failure + 
currentDesc.release(); currentDesc = null; } return null; @@ -151,14 +147,14 @@ protected ByteBuf tryCurrentEntry() { protected IODataDescBase createNextDesc(long sizeLimit) throws IOException { long remaining = sizeLimit; int reduceId = -1; - long mapId; + String mapId; IODataDescBase desc = null; while (remaining > 0) { nextMapReduceId(); if (curMapReduceId == null) { break; } - if (reduceId > 0 && curMapReduceId._2 != reduceId) { // make sure entries under same reduce + if (reduceId > 0 & (curMapReduceId._2 != reduceId)) { // make sure entries under same reduce break; } reduceId = curMapReduceId._2; @@ -176,7 +172,7 @@ protected IODataDescBase createNextDesc(long sizeLimit) throws IOException { if (desc == null) { desc = createFetchDataDesc(String.valueOf(reduceId)); } - addFetchEntry(desc, String.valueOf(mapId), offset, readSize); + addFetchEntry(desc, mapId, offset, readSize); remaining -= readSize; } return desc; diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderSync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderSync.java index 0c962d67..d6c528d6 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderSync.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderSync.java @@ -30,16 +30,15 @@ import io.netty.util.internal.ObjectPool; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; -import scala.Tuple3; import java.io.IOException; import java.util.LinkedHashMap; import java.util.Map; -import java.util.concurrent.*; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; @@ -61,7 +60,7 @@ public class DaosReaderSync extends TaskSubmitter implements DaosReader { private boolean fromOtherThread; - private static Logger logger = LoggerFactory.getLogger(DaosReader.class); + private static Logger logger = LoggerFactory.getLogger(DaosReaderSync.class); /** * construct DaosReader with object and dedicated read executors. 
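A note on the reworked progress() earlier in this patch: the do/while loop polls for completions with whatever remains of the overall wait budget and times out once that budget is exhausted. Below is a minimal, self-contained sketch of the same retry-with-remaining-budget pattern; the CompletionQueue interface and every name in it are hypothetical stand-ins for illustration, not the DAOS event-queue API.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeoutException;

public class BudgetedPoller {
  // Stand-in for an async completion source; poll() drains up to 'max' finished
  // items into 'out', waiting at most 'waitMs'.
  interface CompletionQueue<T> {
    void poll(List<T> out, int max, long waitMs) throws InterruptedException;
  }

  // Keep polling until at least one completion arrives, but never spend more than
  // timeOutMs in total; each retry only waits for the remaining budget.
  static <T> List<T> pollUntilCompleted(CompletionQueue<T> eq, int inFlight, long timeOutMs)
      throws InterruptedException, TimeoutException {
    List<T> completed = new ArrayList<>();
    long start = System.currentTimeMillis();
    long dur = 0L;
    do {
      eq.poll(completed, inFlight, timeOutMs - dur);
      if (completed.isEmpty()) {
        dur = System.currentTimeMillis() - start;
        if (dur > timeOutMs) {
          throw new TimeoutException("timed out after " + dur + " ms");
        }
      }
    } while (completed.isEmpty());
    return completed;
  }
}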
@@ -143,14 +142,14 @@ public void checkTotalPartitions() throws IOException { } @Override - public void prepare(LinkedHashMap, Tuple3> partSizeMap, + public void prepare(LinkedHashMap, Tuple2> partSizeMap, long maxBytesInFlight, long maxReqSizeShuffleToMem, ShuffleReadMetricsReporter metrics) { reader.prepare(partSizeMap, maxBytesInFlight, maxReqSizeShuffleToMem, metrics); } @Override - public Tuple2 curMapReduceId() { + public Tuple2 curMapReduceId() { return reader.lastMapReduceIdForSubmit; } @@ -338,25 +337,22 @@ private IODataDescBase tryGetFromOtherThread() throws InterruptedException, IOEx private IODataDescBase waitForValidFromOtherThread() throws InterruptedException, IOException { IODataDescBase desc; - while (true) { - long start = System.nanoTime(); - boolean timeout = waitForCondition(config.getWaitDataTimeMs()); - metrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); - if (timeout) { - exceedWaitTimes++; - if (logger.isDebugEnabled()) { - logger.debug("exceed wait: {}ms, times: {}", config.getWaitDataTimeMs(), exceedWaitTimes); - } - if (exceedWaitTimes >= config.getWaitTimeoutTimes()) { - return null; - } - } - // get some results after wait - desc = tryGetValidCompleted(); - if (desc != null) { - return desc; + long start = System.nanoTime(); + boolean timeout = waitForCondition(config.getWaitDataTimeMs()); + metrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + if (timeout) { + exceedWaitTimes++; + if (logger.isDebugEnabled()) { + logger.debug("exceed wait: {}ms, times: {}", config.getWaitDataTimeMs(), exceedWaitTimes); } + return null; + } + // get some results after wait + desc = tryGetValidCompleted(); + if (desc != null) { + return desc; } + return null; } private void submitMore() throws IOException { @@ -387,7 +383,7 @@ private ByteBuf getBySelfAndSubmitMore(long selfReadLimit) throws IOException { entryIdx = 0; // fetch the next by self IODataDescSync desc = (IODataDescSync) createNextDesc(selfReadLimit); - Tuple2 mapreduceId = reader.lastMapReduceIdForSubmit; + Tuple2 mapreduceId = reader.lastMapReduceIdForSubmit; try { if (fromOtherThread) { submitMore(); @@ -414,10 +410,10 @@ protected void addFetchEntry(IODataDescBase desc, String mapId, long offset, lon if (readSize > Integer.MAX_VALUE) { throw new IllegalArgumentException("readSize should not exceed " + Integer.MAX_VALUE); } - ((IODataDescSync)desc).addEntryForFetch(String.valueOf(mapId), offset, (int)readSize); + ((IODataDescSync)desc).addEntryForFetch(mapId, offset, (int)readSize); } - private ByteBuf getBySelf(IODataDescSync desc, Tuple2 mapreduceId) throws IOException { + private ByteBuf getBySelf(IODataDescSync desc, Tuple2 mapreduceId) throws IOException { // get data by self, no need to release currentDesc if (desc == null) { // reach end return null; @@ -514,8 +510,8 @@ public ReadTaskContext getNext() { return (ReadTaskContext) next; } - public Tuple2 getMapReduceId() { - return (Tuple2) morePara; + public Tuple2 getMapReduceId() { + return (Tuple2) morePara; } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java index e2883cb7..49341f32 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java @@ -124,6 +124,10 @@ public DaosReader getDaosReader(int shuffleId) throws IOException { return 
ioManager.getDaosReader(shuffleId); } + public DaosReader getDaosParallelReader(int shuffleId) throws IOException { + return ioManager.getDaosParallelReader(shuffleId); + } + private String getKey(long appId, int shuffleId) { return appId + "" + shuffleId; } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java index 0308fa1e..d79a9be5 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java @@ -27,24 +27,21 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManagerId; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import scala.Tuple2; -import scala.Tuple3; import javax.annotation.concurrent.NotThreadSafe; import java.io.IOException; import java.io.InputStream; -import java.util.*; +import java.util.LinkedHashMap; @NotThreadSafe /** - * A inputstream for reading shuffled data being consisted of multiple map outputs. + * A inputstream for reading shuffled data being consisted of multiple map outputs one by one. * * All records in one specific map output are from same KryoSerializer or Java serializer. To facilitate reading * multiple map outputs in this one inputstream, the read methods return -1 to indicate the completion of current - * map output. Caller should call {@link DaosShuffleInputStream#isCompleted()} to check if all map outputs are read. + * map output. Caller should call {@link DaosShuffleInputStream#isCompleted()} to check if all map outputs + * are read. * * To read more data from next map output, user should call {@link #nextMap()} before read. */ @@ -61,10 +58,8 @@ public class DaosShuffleInputStream extends InputStream { private boolean completed; // ensure the order of partition - // (mapid, reduceid) -> (length, BlockId, BlockManagerId) - private LinkedHashMap, Tuple3> partSizeMap; - - private static final Logger log = LoggerFactory.getLogger(DaosShuffleInputStream.class); + // (mapid, reduceid) -> (length, BlockId) + private LinkedHashMap, Tuple2> partSizeMap; /** * constructor with ordered map outputs info. 
Check {@link DaosReader.ReaderConfig} for more paras controlling @@ -83,7 +78,7 @@ public class DaosShuffleInputStream extends InputStream { */ public DaosShuffleInputStream( DaosReader reader, - LinkedHashMap, Tuple3> partSizeMap, + LinkedHashMap, Tuple2> partSizeMap, long maxBytesInFlight, long maxReqSizeShuffleToMem, ShuffleReadMetricsReporter metrics) { this.partSizeMap = partSizeMap; @@ -100,16 +95,9 @@ public BlockId getCurBlockId() { return partSizeMap.get(reader.curMapReduceId())._2(); } - public BlockManagerId getCurOriginAddress() { - if (reader.curMapReduceId() == null) { - return null; - } - return partSizeMap.get(reader.curMapReduceId())._3(); - } - - public long getCurMapIndex() { + public String getCurMapIndex() { if (reader.curMapReduceId() == null) { - return -1; + return "-1"; } return reader.curMapReduceId()._1; } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleOutputStream.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleOutputStream.java index cda78595..ccde575c 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleOutputStream.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleOutputStream.java @@ -66,7 +66,7 @@ public void flush() throws IOException { @Override public void close() throws IOException { - daosWriter.flush(partitionId); + daosWriter.flushAll(partitionId); } public long getWrittenBytes() { diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java index 2afa28ba..b1b9f6bf 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java @@ -26,6 +26,7 @@ import io.daos.BufferAllocator; import io.daos.obj.DaosObject; import io.daos.obj.IODataDescSync; +import io.daos.obj.IODescUpdAsync; import io.daos.obj.IOSimpleDDAsync; import io.netty.buffer.ByteBuf; import org.apache.spark.SparkConf; @@ -35,6 +36,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; /** @@ -71,6 +73,30 @@ public interface DaosWriter { */ void write(int partitionId, byte[] array, int offset, int len); + /** + * enable spilling to DAOS + */ + void enableSpill(); + + /** + * increment seq for spilling data. + * + * @param partitionId + */ + void incrementSeq(int partitionId); + + /** + * mark it's merging phase. Records should be written to final akey, mapId. + */ + void startMerging(); + + /** + * is spilling actual happened on specific partition. + * + * @param partitionId + */ + boolean isSpilled(int partitionId); + /** * get length of all partitions. * 0 for empty partition. @@ -81,7 +107,7 @@ public interface DaosWriter { long[] getPartitionLens(int numPartitions); /** - * Flush specific partition to DAOS. + * Flush specific partition to DAOS. Some of non-full buffers may not flushed to DAOS. * * @param partitionId * @throws IOException @@ -89,7 +115,15 @@ public interface DaosWriter { void flush(int partitionId) throws IOException; /** - * Flush all pending writes. + * Flush specific partition to DAOS. All buffers are flushed not matter whether it's spill or not. + * + * @param partitionId + * @throws IOException + */ + void flushAll(int partitionId) throws IOException; + + /** + * Flush all pending writes and wait for completion if it's async op. 
* * @throws IOException */ @@ -100,6 +134,25 @@ public interface DaosWriter { */ void close(); + /** + * get list of spilled. + * + * @param partitionId + * @return + */ + List getSpillInfo(int partitionId); + + interface ObjectCache { + + T get(); + + T newObject(); + + void put(T object); + + boolean isFull(); + } + /** * Write parameters, including mapId, shuffleId, number of partitions and write config. */ @@ -146,29 +199,61 @@ public WriterConfig getConfig() { } } + class SpillInfo { + private String reduceId; + private String mapId; + private long size; + + public SpillInfo(String partitionId, String cmapId, long roundSize) { + this.reduceId = partitionId; + this.mapId = cmapId; + this.size = roundSize; + } + + public String getReduceId() { + return reduceId; + } + + public String getMapId() { + return mapId; + } + + public long getSize() { + return size; + } + } + /** * Write data to one or multiple netty direct buffers which will be written to DAOS without copy */ class NativeBuffer implements Comparable { private String mapId; + private int seq; + private boolean needSpill; private int partitionId; private String partitionIdKey; private int bufferSize; private int idx = -1; private List bufList = new ArrayList<>(); + private List spillInfos; private long totalSize; - private long roundSize; + private long submittedSize; + private int nbrOfSubmitted; private DaosObject object; private static final Logger LOG = LoggerFactory.getLogger(NativeBuffer.class); - NativeBuffer(DaosObject object, String mapId, int partitionId, int bufferSize) { + NativeBuffer(DaosObject object, String mapId, int partitionId, int bufferSize, boolean needSpill) { this.object = object; this.mapId = mapId; this.partitionId = partitionId; this.partitionIdKey = String.valueOf(partitionId); this.bufferSize = bufferSize; + this.needSpill = needSpill; + if (needSpill) { + spillInfos = new ArrayList<>(); + } } private ByteBuf addNewByteBuf(int len) { @@ -197,7 +282,6 @@ public void write(int b) { buf = addNewByteBuf(1); } buf.writeByte(b); - roundSize += 1; } public void write(byte[] b) { @@ -218,50 +302,148 @@ public void write(byte[] b, int offset, int len) { buf = addNewByteBuf(gap); buf.writeBytes(b, avail, gap); } - roundSize += len; } - public IODataDescSync createUpdateDesc() throws IOException { - if (roundSize == 0 || bufList.isEmpty()) { - return null; + private String currentMapId() { + if (needSpill & seq == 0) { + seq = 1; + } + return seq > 0 ? mapId + "_" + seq : mapId; + } + + /** + * create list of {@link IODataDescSync} each of them has only one akey entry. + * DAOS has a constraint that same akey cannot be referenced twice in one IO. + * + * @return list of {@link IODataDescSync} + * @throws IOException + */ + public List createUpdateDescs() throws IOException { + // make sure each spilled data don't span multiple mapId_s. + return createUpdateDescs(true); + } + + /** + * create list of {@link IODataDescSync} each of them has only one akey entry. + * DAOS has a constraint that same akey cannot be referenced twice in one IO. + * + * @param fullBufferOnly + * if write full buffer only to DAOS? + * @return list of {@link IODataDescSync} + * @throws IOException + */ + public List createUpdateDescs(boolean fullBufferOnly) throws IOException { + int nbrOfBuf = bufList.size(); + if ((nbrOfBuf == 0) | (fullBufferOnly & (nbrOfBuf <= 1))) { + return Collections.emptyList(); } + nbrOfBuf -= fullBufferOnly ? 
1 : 0; + + List descList = new ArrayList<>(nbrOfBuf); + String cmapId = currentMapId(); long bufSize = 0; - IODataDescSync desc = object.createDataDescForUpdate(partitionIdKey, IODataDescSync.IodType.ARRAY, 1); - for (ByteBuf buf : bufList) { - desc.addEntryForUpdate(mapId, totalSize, buf); + long offset = needSpill ? 0 : totalSize; + for (int i = 0; i < nbrOfBuf; i++) { + IODataDescSync desc = object.createDataDescForUpdate(partitionIdKey, IODataDescSync.IodType.ARRAY, 1); + ByteBuf buf = bufList.get(i); + desc.addEntryForUpdate(cmapId, offset + bufSize, buf); bufSize += buf.readableBytes(); + descList.add(desc); } - if (roundSize != bufSize) { - throw new IOException("expect update size: " + roundSize + ", actual: " + bufSize); + nbrOfSubmitted = nbrOfBuf; + submittedSize = bufSize; + addSpill(cmapId, bufSize); + return descList; + } + + /** + * create list of {@link IODescUpdAsync} each of them has only one akey entry. + * DAOS has a constraint that same akey cannot be referenced twice in one IO. + * + * @param eqHandle + * @return list of {@link IODescUpdAsync} + * @throws IOException + */ + public List createUpdateDescAsyncs(long eqHandle, ObjectCache cache) + throws IOException { + // make sure each spilled data don't span multiple mapId_s. + return createUpdateDescAsyncs(eqHandle, cache, true); + } + + /** + * create list of {@link IOSimpleDDAsync} each of them has only one akey entry. + * DAOS has a constraint that same akey cannot be referenced twice in one IO. + * + * @param eqHandle + * @param fullBufferOnly + * if write full buffer only to DAOS? + * @return list of {@link IOSimpleDDAsync} + * @throws IOException + */ + public List createUpdateDescAsyncs(long eqHandle, ObjectCache cache, + boolean fullBufferOnly) throws IOException { + int nbrOfBuf = bufList.size(); + if ((nbrOfBuf == 0) | (fullBufferOnly & (nbrOfBuf <= 1))) { + return Collections.emptyList(); } - return desc; - } + nbrOfBuf -= fullBufferOnly ? 1 : 0; - public IOSimpleDDAsync createUpdateDescAsync(long eqHandle) throws IOException { - if (roundSize == 0 || bufList.isEmpty()) { - return null; - } + List descList = new ArrayList<>(nbrOfBuf); + String cmapId = currentMapId(); long bufSize = 0; - IOSimpleDDAsync desc = object.createAsyncDataDescForUpdate(partitionIdKey, eqHandle); - for (ByteBuf buf : bufList) { - desc.addEntryForUpdate(mapId, totalSize, buf); + long offset = needSpill ? 
0 : totalSize; + for (int i = 0; i < nbrOfBuf; i++) { + IODescUpdAsync desc; + ByteBuf buf = bufList.get(i); + if (!cache.isFull()) { + desc = cache.get(); + desc.reuse(); + desc.setDkey(partitionIdKey); + desc.setAkey(cmapId); + desc.setOffset(offset + bufSize); + desc.setDataBuffer(buf); + } else { + desc = new IODescUpdAsync(partitionIdKey, cmapId, offset + bufSize, buf); + } bufSize += buf.readableBytes(); + descList.add(desc); } - if (roundSize != bufSize) { - throw new IOException("expect update size: " + roundSize + ", actual: " + bufSize); + + nbrOfSubmitted = nbrOfBuf; + submittedSize = bufSize; + addSpill(cmapId, bufSize); + return descList; + } + + private void addSpill(String cmapId, long roundSize) { + if (needSpill) { + LOG.info("reduce ID: " + partitionIdKey + ", map ID: " + cmapId + ", spilling to DAOS, size: " + roundSize); + spillInfos.add(new SpillInfo(partitionIdKey, cmapId, roundSize)); } - return desc; + } + + public List getSpillInfo() { + return spillInfos; } public void reset(boolean release) { if (release) { - bufList.forEach(b -> b.release()); + for (int i = 0; i < nbrOfSubmitted; i++) { + bufList.get(i).release(); + } } // release==false, buffers will be released when tasks are executed and consumed + int nbrOfBufs = bufList.size(); + ByteBuf lastBuf = nbrOfBufs > 0 ? bufList.get(nbrOfBufs - 1) : null; bufList.clear(); idx = -1; - totalSize += roundSize; - roundSize = 0; + if (nbrOfSubmitted < nbrOfBufs) { // add back last buffer + bufList.add(lastBuf); + idx = 0; + } + totalSize += submittedSize; + submittedSize = 0; + nbrOfSubmitted = 0; } @Override @@ -277,13 +459,40 @@ public long getTotalSize() { return totalSize; } - public long getRoundSize() { - return roundSize; + public long getSubmittedSize() { + return submittedSize; + } + + public String getPartitionIdKey() { + return partitionIdKey; } public List getBufList() { return bufList; } + + public boolean isSpilled() { + return !spillInfos.isEmpty(); + } + + /** + * prepare for caching merged records + */ + public void startMerging() { + if (LOG.isDebugEnabled()) { + LOG.debug("start merging for " + partitionId + ". 
current total size: " + totalSize); + } + reset(true); + totalSize = 0; + this.needSpill = false; + this.seq = 0; + } + + public void incrementSeq() { + if (needSpill) { + seq++; + } + } } /** @@ -295,11 +504,11 @@ class WriterConfig { private boolean warnSmallWrite; private int asyncWriteBatchSize; private long waitTimeMs; - private int timeoutTimes; private long totalInMemSize; private int totalSubmittedLimit; private int threads; private boolean fromOtherThreads; + private int ioDescCaches; private SparkConf conf; private static final Logger logger = LoggerFactory.getLogger(WriterConfig.class); @@ -310,15 +519,15 @@ class WriterConfig { this.conf = SparkEnv.get().conf(); } warnSmallWrite = (boolean) conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WARN_SMALL_SIZE()); - bufferSize = (int) ((long) conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE()) - * 1024 * 1024); + bufferSize = (int) ((long) conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE()) * 1024); + bufferSize += bufferSize * 0.1; // 10% more for metadata overhead and upper layer deviation minSize = (int) ((long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MINIMUM_SIZE()) * 1024); asyncWriteBatchSize = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_ASYNC_WRITE_BATCH_SIZE()); - timeoutTimes = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WAIT_DATA_TIMEOUT_TIMES()); waitTimeMs = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WAIT_MS()); totalInMemSize = (long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MAX_BYTES_IN_FLIGHT()) * 1024; totalSubmittedLimit = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SUBMITTED_LIMIT()); threads = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_THREADS()); + ioDescCaches = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_ASYNC_DESC_CACHES()); fromOtherThreads = (boolean)conf .get(package$.MODULE$.SHUFFLE_DAOS_WRITE_IN_OTHER_THREAD()); if (logger.isDebugEnabled()) { @@ -343,10 +552,6 @@ public long getWaitTimeMs() { return waitTimeMs; } - public int getTimeoutTimes() { - return timeoutTimes; - } - public long getTotalInMemSize() { return totalInMemSize; } @@ -363,6 +568,10 @@ public int getThreads() { return threads; } + public int getIoDescCaches() { + return ioDescCaches; + } + public boolean isFromOtherThreads() { return fromOtherThreads; } @@ -374,7 +583,6 @@ public String toString() { ", minSize=" + minSize + ", warnSmallWrite=" + warnSmallWrite + ", waitTimeMs=" + waitTimeMs + - ", timeoutTimes=" + timeoutTimes + ", totalInMemSize=" + totalInMemSize + ", totalSubmittedLimit=" + totalSubmittedLimit + ", threads=" + threads + diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterAsync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterAsync.java index 6d55d287..bcaefd7b 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterAsync.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterAsync.java @@ -26,26 +26,34 @@ import io.daos.DaosEventQueue; import io.daos.TimedOutException; import io.daos.obj.DaosObject; -import io.daos.obj.IOSimpleDDAsync; +import io.daos.obj.IODescUpdAsync; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.*; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; public class DaosWriterAsync extends DaosWriterBase { private DaosEventQueue eq; - private Set descSet = new 
LinkedHashSet<>(); + private Set descSet = new LinkedHashSet<>(); private List completedList = new LinkedList<>(); - private static final Logger log = LoggerFactory.getLogger(DaosWriterAsync.class); + private AsyncDescCache cache; + + private static Logger log = LoggerFactory.getLogger(DaosWriterAsync.class); public DaosWriterAsync(DaosObject object, WriteParam param) throws IOException { super(object, param); eq = DaosEventQueue.getInstance(0); + cache = new AsyncDescCache(param.getConfig().getIoDescCaches()); } @Override @@ -54,73 +62,127 @@ public void flush(int partitionId) throws IOException { if (buffer == null) { return; } - IOSimpleDDAsync desc = buffer.createUpdateDescAsync(eq.getEqWrapperHdl()); - if (desc == null) { - buffer.reset(true); + List descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache); + flush(buffer, descList); + } + + @Override + public void flushAll(int partitionId) throws IOException { + NativeBuffer buffer = partitionBufArray[partitionId]; + if (buffer == null) { return; } - DaosEventQueue.Event event = acquireEvent(); - descSet.add(desc); - desc.setEvent(event); - try { - object.updateAsync(desc); - buffer.reset(false); - } catch (Exception e) { + List descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache, false); + flush(buffer, descList); + } + + private void cacheOrRelease(IODescUpdAsync desc) { + if (desc.isReusable()) { + cache.put(desc); + } else { desc.release(); - descSet.remove(desc); - throw e; } - if (descSet.size() >= config.getAsyncWriteBatchSize()) { - flushAll(); + } + + private void flush(NativeBuffer buffer, List descList) throws IOException { + if (!descList.isEmpty()) { + assert Thread.currentThread().getId() == eq.getThreadId() : "current thread " + Thread.currentThread().getId() + + "(" + Thread.currentThread().getName() + "), is not expected " + eq.getThreadId() + "(" + + eq.getThreadName() + ")"; + for (IODescUpdAsync desc : descList) { + DaosEventQueue.Event event = acquireEvent(); + descSet.add(desc); + desc.setEvent(event); + try { + object.updateAsync(desc); + } catch (Exception e) { + cacheOrRelease(desc); + desc.discard(); + descSet.remove(desc); + throw e; + } + } + buffer.reset(false); + + if (descSet.size() >= config.getAsyncWriteBatchSize()) { + pollOnce(descSet.size(), 1L); + } } } @Override public void flushAll() throws IOException { + for (int i = 0; i < partitionBufArray.length; i++) { + NativeBuffer buffer = partitionBufArray[i]; + if (buffer == null) { + continue; + } + List descList = buffer.createUpdateDescAsyncs(eq.getEqWrapperHdl(), cache, false); + flush(buffer, descList); + } + waitCompletion(); + } + + @Override + protected void waitCompletion() throws IOException { int left; try { - while ((left=descSet.size()) > 0) { - completedList.clear(); - int n = eq.pollCompleted(completedList, left, config.getWaitTimeMs()); - if (n == 0) { - throw new TimedOutException("timed out after " + config.getWaitTimeMs()); - } - verifyCompleted(); + long dur; + long start = System.currentTimeMillis(); + while ((left = descSet.size()) > 0 & ((dur = System.currentTimeMillis() - start) < config.getWaitTimeMs())) { + pollOnce(left, config.getWaitTimeMs() - dur); + } + if (!descSet.isEmpty()) { + throw new TimedOutException("timed out after " + (System.currentTimeMillis() - start)); } } catch (IOException e) { throw new IllegalStateException("failed to complete all running updates. 
", e); - } finally { - descSet.forEach(desc -> desc.release()); - descSet.clear(); } + super.flushAll(); } - private DaosEventQueue.Event acquireEvent() throws IOException { + private void pollOnce(int nbr, long timeoutMs) throws IOException { completedList.clear(); - DaosEventQueue.Event event = eq.acquireEventBlocking(config.getWaitTimeMs(), completedList); + eq.pollCompleted(completedList, IODescUpdAsync.class, descSet, nbr, timeoutMs); verifyCompleted(); - return event; + } + + private DaosEventQueue.Event acquireEvent() throws IOException { + completedList.clear(); + try { + DaosEventQueue.Event event = eq.acquireEventBlocking(config.getWaitTimeMs(), completedList, + IODescUpdAsync.class, descSet); + verifyCompleted(); + return event; + } catch (IOException e) { + log.error("EQNBR: " + eq.getNbrOfEvents() + ", " + eq.getNbrOfAcquired() + ", " + descSet.size()); + throw e; + } } private void verifyCompleted() throws IOException { - IOSimpleDDAsync failed = null; + IODescUpdAsync failed = null; int failedCnt = 0; for (DaosEventQueue.Attachment attachment : completedList) { - if (descSet.contains(attachment)) { - attachment.release(); - descSet.remove(attachment); - IOSimpleDDAsync desc = (IOSimpleDDAsync) attachment; - if (desc.isSucceeded()) { - continue; - } + descSet.remove(attachment); + IODescUpdAsync desc = (IODescUpdAsync) attachment; + if (!desc.isSucceeded()) { failedCnt++; if (failed == null) { - failed = desc; + failed = desc; // release after toString so that akey info is captured + continue; } } + if (log.isDebugEnabled()) { + log.debug("written desc: " + desc); + } + cacheOrRelease(desc); } if (failedCnt > 0) { - throw new IOException("failed to write " + failedCnt + " IOSimpleDDAsync. First failed is " + failed); + IOException e = new IOException("failed to write " + failedCnt + " IOSimpleDDAsync. Return code is " + + failed.getReturnCode() + ". 
First failed is " + failed); + cacheOrRelease(failed); + throw e; } } @@ -135,10 +197,85 @@ public void close() { completedList.clear(); completedList = null; } + + if (descSet.isEmpty()) { // all descs polled + cache.release(); + } else { + descSet.forEach(d -> d.discard()); // to be released when poll + cache.release(descSet); + descSet.clear(); + } + super.close(); } public void setWriterMap(Map writerMap) { writerMap.put(this, 0); this.writerMap = writerMap; } + + static class AsyncDescCache implements ObjectCache { + private int idx; + private int total; + private IODescUpdAsync[] array; + + AsyncDescCache(int maxNbr) { + this.array = new IODescUpdAsync[maxNbr]; + } + + @Override + public IODescUpdAsync get() { + if (idx < total) { + return array[idx++]; + } + if (idx < array.length) { + array[idx] = newObject(); + total++; + return array[idx++]; + } + throw new IllegalStateException("cache is full, " + total); + } + + @Override + public IODescUpdAsync newObject() { + return new IODescUpdAsync(32); + } + + @Override + public void put(IODescUpdAsync desc) { + if (idx <= 0) { + throw new IllegalStateException("more than actual number of IODescUpdAsyncs put back"); + } + if (!desc.isDiscarded()) { + desc.releaseDataBuffer(); + } else { + desc.release(); + desc = newObject(); + } + array[--idx] = desc; + } + + @Override + public boolean isFull() { + return idx >= array.length; + } + + public void release() { + release(Collections.emptySet()); + } + + private void release(Set filterSet) { + for (int i = 0; i < Math.min(total, array.length); i++) { + IODescUpdAsync desc = array[i]; + if (desc != null && !filterSet.contains(desc)) { + desc.release(); + } + } + array = null; + idx = 0; + } + + protected int getIdx() { + return idx; + } + } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterBase.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterBase.java index 1d3e350d..bfe419c0 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterBase.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterBase.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.List; import java.util.Map; public abstract class DaosWriterBase implements DaosWriter { @@ -36,6 +37,8 @@ public abstract class DaosWriterBase implements DaosWriter { protected String mapId; + protected boolean needSpill; + protected WriteParam param; protected WriterConfig config; @@ -57,7 +60,7 @@ protected DaosWriterBase(DaosObject object, WriteParam param) { protected NativeBuffer getNativeBuffer(int partitionId) { NativeBuffer buffer = partitionBufArray[partitionId]; if (buffer == null) { - buffer = new NativeBuffer(object, mapId, partitionId, config.getBufferSize()); + buffer = new NativeBuffer(object, mapId, partitionId, config.getBufferSize(), needSpill); partitionBufArray[partitionId] = buffer; } return buffer; @@ -78,18 +81,72 @@ public void write(int partitionId, byte[] array, int offset, int len) { getNativeBuffer(partitionId).write(array, offset, len); } + @Override + public void enableSpill() { + this.needSpill = true; + } + + @Override + public void incrementSeq(int partitionId) { + NativeBuffer buffer = partitionBufArray[partitionId]; + if (buffer != null) { + buffer.incrementSeq(); + } + } + + @Override + public void startMerging() { + if (!needSpill) { + throw new IllegalStateException("startMerging called twice or non-spillable partition"); + } + needSpill = false; + // make sure 
all pending writes done, like async write + try { + waitCompletion(); + } catch (IOException e) { + throw new IllegalStateException("failed to flush all existing writes", e); + } + + for (NativeBuffer buffer : partitionBufArray) { + if (buffer == null) { + continue; + } + buffer.startMerging(); + } + } + + @Override + public boolean isSpilled(int partitionId) { + NativeBuffer buffer = partitionBufArray[partitionId]; + if (buffer == null) { + return false; + } + return buffer.isSpilled(); + } + + @Override + public List getSpillInfo(int partitionId) { + NativeBuffer buffer = partitionBufArray[partitionId]; + if (buffer == null) { + return null; + } + return buffer.getSpillInfo(); + } + + protected abstract void waitCompletion() throws IOException; + @Override public void flushAll() throws IOException {} @Override public long[] getPartitionLens(int numPartitions) { if (LOG.isDebugEnabled()) { - LOG.debug("partition map size: " + partitionBufArray.length); + LOG.debug("partition map " + mapId +", size: " + partitionBufArray.length); for (int i = 0; i < numPartitions; i++) { NativeBuffer nb = partitionBufArray[i]; if (nb != null) { LOG.debug("id: " + i + ", native buffer: " + nb.getPartitionId() + ", " + - nb.getTotalSize() + ", " + nb.getRoundSize()); + nb.getTotalSize() + ", " + nb.getSubmittedSize()); } } } @@ -98,8 +155,8 @@ public long[] getPartitionLens(int numPartitions) { NativeBuffer nb = partitionBufArray[i]; if (nb != null) { lens[i] = nb.getTotalSize(); - if (nb.getRoundSize() != 0 || !nb.getBufList().isEmpty()) { - throw new IllegalStateException("round size should be 0, " + nb.getRoundSize() + + if (nb.getSubmittedSize() != 0 || !nb.getBufList().isEmpty()) { + throw new IllegalStateException("round size should be 0, " + nb.getSubmittedSize() + ", buflist should be empty, " + nb.getBufList().size()); } @@ -109,4 +166,9 @@ public long[] getPartitionLens(int numPartitions) { } return lens; } + + @Override + public void close() { + partitionBufArray = null; + } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterSync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterSync.java index b045a0ba..4d905d89 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterSync.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterSync.java @@ -27,13 +27,11 @@ import io.daos.obj.DaosObject; import io.daos.obj.IODataDesc; import io.daos.obj.IODataDescSync; -import io.netty.buffer.ByteBuf; import io.netty.util.internal.ObjectPool; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ThreadFactory; @@ -105,15 +103,47 @@ public long[] getPartitionLens(int numPartitions) { return iw.getPartitionLens(numPartitions); } + @Override + public void enableSpill() { + iw.enableSpill(); + } + + @Override + public void startMerging() { + iw.startMerging(); + } + + @Override + public void incrementSeq(int partitionId) { + iw.incrementSeq(partitionId); + } + + @Override + public boolean isSpilled(int partitionId) { + return iw.isSpilled(partitionId); + } + + @Override + public List getSpillInfo(int partitionId) { + return iw.getSpillInfo(partitionId); + } + @Override public void flush(int partitionId) throws IOException { iw.flush(partitionId); } @Override - public void flushAll() {} + public void flushAll(int partitionId) throws IOException { + iw.flushAll(); + } + + @Override + public 
void flushAll() throws IOException { + iw.flushAll(); + } - private void runBySelf(IODataDescSync desc, NativeBuffer buffer) throws IOException { + private void runBySelf(IODataDescSync desc) throws IOException { totalBySelfTimes++; try { object.update(desc); @@ -121,34 +151,29 @@ private void runBySelf(IODataDescSync desc, NativeBuffer buffer) throws IOExcept throw new IOException("failed to write partition of " + desc, e); } finally { desc.release(); - buffer.reset(true); } } - private void submitToOtherThreads(IODataDescSync desc, NativeBuffer buffer) throws IOException { + private void submitToOtherThreads(IODataDescSync desc) throws IOException { // move forward to release write buffers moveForward(); // check if we need to wait submitted tasks to be executed if (goodForSubmit()) { - submitAndReset(desc, buffer); + submit(desc, null); return; } // to wait - int timeoutTimes = 0; try { while (!goodForSubmit()) { boolean timeout = waitForCondition(config.getWaitTimeMs()); moveForward(); if (timeout) { - timeoutTimes++; if (LOG.isDebugEnabled()) { - LOG.debug("wait daos write timeout times: " + timeoutTimes); - } - if (timeoutTimes >= config.getTimeoutTimes()) { - totalTimeoutTimes += timeoutTimes; - runBySelf(desc, buffer); - return; + LOG.debug("wait daos write timed out after " + config.getWaitTimeMs()); } + totalTimeoutTimes++; + runBySelf(desc); + return; } } } catch (InterruptedException e) { @@ -156,30 +181,21 @@ private void submitToOtherThreads(IODataDescSync desc, NativeBuffer buffer) thro Thread.currentThread().interrupt(); throw new IOException("interrupted when wait daos write", e); } - // submit write task after some wait - totalTimeoutTimes += timeoutTimes; - submitAndReset(desc, buffer); + // submit task after some wait + submit(desc, null); } private boolean goodForSubmit() { return totalInMemSize < config.getTotalInMemSize() && totalSubmitted < config.getTotalSubmittedLimit(); } - private void submitAndReset(IODataDescSync desc, NativeBuffer buffer) { - try { - submit(desc, buffer.getBufList()); - } finally { - buffer.reset(false); - } - } - - private void cleanup(boolean force) { + private void cleanup() { if (cleaned) { return; } boolean allReleased = true; - allReleased &= cleanupSubmitted(force); - allReleased &= cleanupConsumed(force); + allReleased &= cleanupSubmitted(true); + allReleased &= cleanupConsumed(true); if (allReleased) { cleaned = true; } @@ -191,12 +207,13 @@ private void cleanup(boolean force) { @Override public void close() { iw.close(); + if (writerMap != null && cleaned) { + writerMap.remove(this); + writerMap = null; + } } - private void waitCompletion(boolean force) throws Exception { - if (!force) { - return; - } + private void waitCompletion() throws Exception { try { while (totalSubmitted > 0) { waitForCondition(config.getWaitTimeMs()); @@ -233,11 +250,7 @@ protected boolean validateReturned(LinkedTaskContext context) throws IOException @Override protected boolean consumed(LinkedTaskContext context) { - // release write buffers - @SuppressWarnings("unchecked") - List bufList = (List) context.morePara; - bufList.forEach(b -> b.release()); - bufList.clear(); + context.desc.release(); return true; } @@ -252,46 +265,75 @@ public void flush(int partitionId) throws IOException { if (buffer == null) { return; } - IODataDescSync desc = buffer.createUpdateDesc(); - if (desc == null) { + List descList = buffer.createUpdateDescs(); + flush(buffer, descList); + } + + @Override + public void flushAll(int partitionId) throws IOException { + NativeBuffer 
buffer = partitionBufArray[partitionId]; + if (buffer == null) { return; } - totalWriteTimes++; - if (config.isWarnSmallWrite() && buffer.getRoundSize() < config.getMinSize()) { - LOG.warn("too small partition size {}, shuffle {}, map {}, partition {}", - buffer.getRoundSize(), param.getShuffleId(), mapId, partitionId); - } - if (executor == null) { // run write by self - runBySelf(desc, buffer); - return; + List descList = buffer.createUpdateDescs(false); + flush(buffer, descList); + } + + private void flush(NativeBuffer buffer, List descList) throws IOException { + for (IODataDescSync desc : descList) { + totalWriteTimes++; + if (config.isWarnSmallWrite() && buffer.getSubmittedSize() < config.getMinSize()) { + LOG.warn("too small partition size {}, shuffle {}, map {}, partition {}", + buffer.getSubmittedSize(), param.getShuffleId(), mapId, buffer.getPartitionIdKey()); + } + if (executor == null) { // run write by self + runBySelf(desc); + continue; + } + submitToOtherThreads(desc); } - submitToOtherThreads(desc, buffer); + buffer.reset(false); } @Override - public void close() { + protected void waitCompletion() throws IOException { try { - close(true); + DaosWriterSync.this.waitCompletion(); } catch (Exception e) { throw new IllegalStateException("failed to complete all write tasks and cleanup", e); } } - private void close(boolean force) throws Exception { - if (partitionBufArray != null) { - waitCompletion(force); - partitionBufArray = null; - object = null; - if (LOG.isDebugEnabled()) { - LOG.debug("total writes: " + totalWriteTimes + ", total timeout times: " + totalTimeoutTimes + - ", total write-by-self times: " + totalBySelfTimes + ", total timeout times/total writes: " + - ((float) totalTimeoutTimes) / totalWriteTimes); + @Override + public void flushAll() throws IOException { + for (int i = 0; i < partitionBufArray.length; i++) { + NativeBuffer buffer = partitionBufArray[i]; + if (buffer == null) { + continue; } + List descList = buffer.createUpdateDescs(false); + flush(buffer, descList); } - cleanup(force); - if (writerMap != null && (force || cleaned)) { - writerMap.remove(this); - writerMap = null; + waitCompletion(); + } + + @Override + public void close() { + try { + if (partitionBufArray != null) { + waitCompletion(); + partitionBufArray = null; + object = null; + if (LOG.isDebugEnabled()) { + LOG.debug("total writes: " + totalWriteTimes + ", total timeout times: " + totalTimeoutTimes + + ", total write-by-self times: " + totalBySelfTimes + ", total timeout times/total writes: " + + ((float) totalTimeoutTimes) / totalWriteTimes); + } + } + cleanup(); + super.close(); + } catch (Exception e) { + throw new IllegalStateException("failed to complete all write tasks and cleanup", e); } } } @@ -356,17 +398,14 @@ static final class WriteTaskContext extends LinkedTaskContext { * condition to signal caller thread * @param desc * desc object to describe where to write data - * @param bufList + * @param morePara * list of buffers to write to DAOS */ WriteTaskContext(DaosObject object, AtomicInteger counter, Lock writeLock, Condition notFull, - IODataDescSync desc, Object bufList) { + IODataDescSync desc, Object morePara) { super(object, counter, writeLock, notFull); this.desc = desc; - @SuppressWarnings("unchecked") - List myBufList = new ArrayList<>(); - myBufList.addAll((List) bufList); - this.morePara = myBufList; + this.morePara = morePara; } @Override @@ -375,17 +414,6 @@ public WriteTaskContext getNext() { WriteTaskContext ctx = (WriteTaskContext) next; return ctx; } - - 
@Override - public void reuse(IODataDescSync desc, Object morePara) { - @SuppressWarnings("unchecked") - List myBufList = (List) this.morePara; - if (!myBufList.isEmpty()) { - throw new IllegalStateException("bufList in reusing write task context should be empty"); - } - myBufList.addAll((List) morePara); - super.reuse(desc, myBufList); - } } /** diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManager.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManager.java index deeab0e6..4c8fa8a4 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManager.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManager.java @@ -23,11 +23,16 @@ package org.apache.spark.shuffle.daos; +import io.daos.DaosObjClassHint; +import io.daos.DaosObjectClass; +import io.daos.DaosObjectType; import io.daos.obj.DaosObjClient; import io.daos.obj.DaosObject; import io.daos.obj.DaosObjectException; import io.daos.obj.DaosObjectId; import org.apache.spark.SparkConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.Map; @@ -40,6 +45,8 @@ public abstract class IOManager { protected DaosObjClient objClient; + private static Logger log = LoggerFactory.getLogger(IOManager.class); + protected IOManager(SparkConf conf, Map objectMap) { this.conf = conf; this.objectMap = objectMap; @@ -57,9 +64,14 @@ protected DaosObject getObject(long appId, int shuffleId) throws DaosObjectExcep String key = getKey(appId, shuffleId); DaosObject object = objectMap.get(key); if (object == null) { + // we use object class hint instead of object class + // so set object class to UNKNOWN DaosObjectId id = new DaosObjectId(appId, shuffleId); - id.encode(); + id.encode(objClient.getContPtr(), DaosObjectType.DAOS_OT_DKEY_UINT64, + DaosObjectClass.OC_UNKNOWN, + DaosObjClassHint.valueOf(conf.get(package$.MODULE$.SHUFFLE_DAOS_OBJECT_HINT())), 0); object = objClient.getObject(id); + log.info("created new object, oid high: " + object.getOid().getHigh() + ", low: " + object.getOid().getLow()); objectMap.putIfAbsent(key, object); DaosObject activeObject = objectMap.get(key); if (activeObject != object) { // release just created DaosObject @@ -73,6 +85,7 @@ protected DaosObject getObject(long appId, int shuffleId) throws DaosObjectExcep object.open(); } } + log.info("oid high: " + object.getOid().getHigh() + ", low: " + object.getOid().getLow()); return object; } @@ -84,5 +97,7 @@ public void setObjClient(DaosObjClient objClient) { abstract DaosReader getDaosReader(int shuffleId) throws IOException; + abstract DaosReader getDaosParallelReader(int shuffleId) throws IOException; + abstract void close() throws IOException; } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerAsync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerAsync.java index 893f05e9..fc62d818 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerAsync.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerAsync.java @@ -79,6 +79,17 @@ DaosReader getDaosReader(int shuffleId) throws IOException { return reader; } + @Override + DaosReader getDaosParallelReader(int shuffleId) throws IOException { + long appId = parseAppId(conf.getAppId()); + if (logger.isDebugEnabled()) { + logger.debug("getting daosparallelreader for app id: " + appId + ", shuffle id: " + shuffleId); + } + DaosParallelReaderAsync reader = new DaosParallelReaderAsync(getObject(appId, shuffleId), 
readerConfig); + reader.setReaderMap(readerMap); + return reader; + } + @Override void close() throws IOException { readerMap.keySet().forEach(r -> r.close(true)); diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerSync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerSync.java index c9158435..d43b248e 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerSync.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerSync.java @@ -28,6 +28,7 @@ import org.apache.spark.launcher.SparkLauncher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.tukaani.xz.UnsupportedOptionsException; import java.io.IOException; import java.util.Map; @@ -117,6 +118,11 @@ DaosReader getDaosReader(int shuffleId) throws IOException { return reader; } + @Override + DaosReader getDaosParallelReader(int shuffleId) throws IOException { + throw new UnsupportedOptionsException("parallel read is not supported in sync API yet"); + } + @Override void close() throws IOException { if (readerExes != null) { diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleManager.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleManager.scala index 30e870bf..59bb9d60 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleManager.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleManager.scala @@ -26,13 +26,15 @@ package org.apache.spark.shuffle.daos import java.lang.reflect.Method import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters.mapAsJavaMapConverter + import io.daos.DaosClient -import scala.collection.JavaConverters._ -import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv, TaskContext} import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.shuffle.sort.BypassMergeSortShuffleHandle +import org.apache.spark.shuffle.sort.SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.collection.OpenHashSet @@ -51,6 +53,13 @@ class DaosShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key) } + if (conf.get(config.MEMORY_OFFHEAP_ENABLED)) { + throw new IllegalArgumentException("DaosShuffleManager doesn't support offheap memory in MemoryManager. 
Please" + + " disable " + config.MEMORY_OFFHEAP_ENABLED) + } + + val shuffleIdSet = ConcurrentHashMap.newKeySet[Integer]() + def findHadoopFs: Method = { try { val fsClass = Utils.classForName("org.apache.hadoop.fs.FileSystem") @@ -117,9 +126,22 @@ class DaosShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { */ override def registerShuffle[K, V, C]( shuffleId: Int, - dependency: ShuffleDependency[K, V, C]): BaseShuffleHandle[K, V, C] + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, dependency) + import DaosShuffleManager._ + shuffleIdSet.add(shuffleId) + if (shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } } override def getWriter[K, V]( @@ -135,20 +157,6 @@ class DaosShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { } override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): DaosShuffleReader[K, C] - = { - val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, startPartition, endPartition) - new DaosShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, - metrics, daosShuffleIO, SparkEnv.get.serializerManager, - shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) - } - - override def getReaderForRange[K, C]( handle: ShuffleHandle, startMapIndex: Int, endMapIndex: Int, @@ -157,25 +165,78 @@ class DaosShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { context: TaskContext, metrics: ShuffleReadMetricsReporter): DaosShuffleReader[K, C] = { - val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange( - handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) - new DaosShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, - metrics, daosShuffleIO, SparkEnv.get.serializerManager, - shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + val baseHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] + val part = SparkEnv.get.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS) + val highlyCompressed = baseHandle.dependency.partitioner.numPartitions > part + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startMapIndex, + endMapIndex, startPartition, endPartition) + + new DaosShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, highlyCompressed, context, + metrics, daosShuffleIO, SparkEnv.get.serializerManager) } - override def unregisterShuffle(shuffleId: Int): Boolean = { + private def removeShuffle(shuffleId: Int): Unit = { logInfo("unregistering shuffle: " + shuffleId) taskIdMapsForShuffle.remove(shuffleId) daosShuffleIO.removeShuffle(shuffleId) } + override 
def unregisterShuffle(shuffleId: Int): Boolean = { + if (SparkContext.DRIVER_IDENTIFIER.equals(SparkEnv.get.executorId) && shuffleIdSet.remove(shuffleId)) { + removeShuffle(shuffleId) + true + } else { + false + } + } + override def shuffleBlockResolver: ShuffleBlockResolver = null override def stop(): Unit = { + if (SparkContext.DRIVER_IDENTIFIER.equals(SparkEnv.get.executorId)) { + shuffleIdSet.forEach(i => { + if (shuffleIdSet.contains(i)) { // make sure cleaner is not working on same shuffle id + removeShuffle(i) + } + }) + shuffleIdSet.clear() + } daosShuffleIO.close() finalizer() ShutdownHookManager.removeShutdownHook(finalizer) logInfo("stopped " + classOf[DaosShuffleManager]) } } + +private object DaosShuffleManager extends Logging { + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + false + } else { + val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) + dep.partitioner.numPartitions <= bypassMergeThreshold + } + } + + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { + val shufId = dependency.shuffleId + val numPartitions = dependency.partitioner.numPartitions + if (!dependency.serializer.supportsRelocationOfSerializedObjects) { + logDebug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + + s"${dependency.serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.mapSideCombine) { + logDebug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " + + s"map-side aggregation") + false + } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + logDebug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") + false + } else { + logDebug(s"Can use serialized shuffle for shuffle $shufId") + true + } + } +} diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleReader.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleReader.scala index 1fa81592..08b02048 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleReader.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleReader.scala @@ -31,14 +31,15 @@ import org.apache.spark.storage.{BlockId, BlockManagerId} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter + class DaosShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + highlyCompressed: Boolean, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, shuffleIO: DaosShuffleIO, - serializerManager: SerializerManager = SparkEnv.get.serializerManager, - shouldBatchFetch: Boolean = false) + serializerManager: SerializerManager = SparkEnv.get.serializerManager) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency @@ -48,19 +49,21 @@ class DaosShuffleReader[K, C]( private val daosReader = shuffleIO.getDaosReader(handle.shuffleId) override def read(): Iterator[Product2[K, C]] = { - val maxBytesInFlight = conf.get(SHUFFLE_DAOS_READ_MAX_BYTES_IN_FLIGHT) - val wrappedStreams = new ShufflePartitionIterator( - context, - blocksByAddress, - serializerManager.wrapStream, - maxBytesInFlight, - 
conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - conf.get(config.SHUFFLE_DETECT_CORRUPT), - conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - readMetrics, - daosReader, - shouldBatchFetch - ).toCompletionIterator + val maxBytesInFlight = conf.get(SHUFFLE_DAOS_READ_MAX_BYTES_IN_FLIGHT) * 1024 + val daosReader = if (!highlyCompressed) shuffleIO.getDaosReader(handle.shuffleId) + else shuffleIO.getDaosParallelReader(handle.shuffleId) + val iterator = new ShufflePartitionIterator( + context, + blocksByAddress, + serializerManager.wrapStream, + maxBytesInFlight, + conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + conf.get(config.SHUFFLE_DETECT_CORRUPT), + conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + readMetrics, + daosReader + ) + val wrappedStreams = iterator.toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleWriter.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleWriter.scala index e55dd043..7fb2bbb6 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleWriter.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/DaosShuffleWriter.scala @@ -23,10 +23,11 @@ package org.apache.spark.shuffle.daos -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.storage.BlockManagerId class DaosShuffleWriter[K, V, C]( handle: BaseShuffleHandle[K, V, C], @@ -43,19 +44,18 @@ class DaosShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null - private val blockManager = SparkEnv.get.blockManager + private val dummyBlkId = BlockManagerId("-1", "dummy-host", 1024) override def write(records: Iterator[Product2[K, V]]): Unit = { -// val start = System.nanoTime() partitionsWriter = if (dep.mapSideCombine) { new MapPartitionsWriter[K, V, C]( handle.shuffleId, context, + shuffleIO, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, - dep.serializer, - shuffleIO) + dep.serializer) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side @@ -63,18 +63,15 @@ class DaosShuffleWriter[K, V, C]( new MapPartitionsWriter[K, V, V]( handle.shuffleId, context, + shuffleIO, aggregator = None, Some(dep.partitioner), ordering = None, - dep.serializer, - shuffleIO) + dep.serializer) } partitionsWriter.insertAll(records) - val partitionLengths = partitionsWriter.commitAll - - // logInfo(context.taskAttemptId() + " all time: " + (System.nanoTime() - start)/1000000) - - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + val partitionLengths = getPartitionLengths + mapStatus = MapStatus(dummyBlkId, partitionLengths, mapId) } override def stop(success: Boolean): Option[MapStatus] = { @@ -98,4 +95,8 @@ class DaosShuffleWriter[K, V, C]( } } } + + def getPartitionLengths(): Array[Long] = { + partitionsWriter.commitAll + } } diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala index b3c10483..4d98821a 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala +++ 
b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala @@ -25,37 +25,43 @@ package org.apache.spark.shuffle.daos import java.util.Comparator -import org.apache.spark._ +import org.apache.spark.{Aggregator, Partitioner, SparkConf, SparkEnv, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.shuffle.daos.MapPartitionsWriter._ class MapPartitionsWriter[K, V, C]( shuffleId: Int, context: TaskContext, + shuffleIO: DaosShuffleIO, aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, - serializer: Serializer = SparkEnv.get.serializer, - shuffleIO: DaosShuffleIO) extends Logging { + serializer: Serializer = SparkEnv.get.serializer) extends Logging { private val conf = SparkEnv.get.conf + val spillFirst = conf.get(SHUFFLE_DAOS_SPILL_FIRST) + val lowGrantWatermark = if (spillFirst) { + val exeMem = conf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY) * 1024 * 1024 + val memFraction = conf.get(org.apache.spark.internal.config.MEMORY_FRACTION) + val execCores = conf.get(org.apache.spark.internal.config.EXECUTOR_CORES) + val cpusPerTask = conf.get(org.apache.spark.internal.config.CPUS_PER_TASK) + val maxMemPerTask = (exeMem - 300 * 1024 * 1024) * memFraction * cpusPerTask / execCores + (maxMemPerTask * conf.get(SHUFFLE_DAOS_SPILL_GRANT_PCT)).toLong + } else { + 0L + } + logInfo("lowGrantWatermark: " + lowGrantWatermark) + private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 private def getPartition(key: K): Int = { if (shouldPartition) partitioner.get.getPartition(key) else 0 } - private val serializerManager = SparkEnv.get.serializerManager - private val serInstance = serializer.newInstance() - - private val daosWriter = shuffleIO.getDaosWriter( - numPartitions, - shuffleId, - context.taskAttemptId()) - private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - /* key comparator if map-side combiner is defined */ private val keyComparator: Comparator[K] = ordering.getOrElse((a: K, b: K) => { val h1 = if (a == null) 0 else a.hashCode() @@ -72,17 +78,18 @@ class MapPartitionsWriter[K, V, C]( } // buffer by partition - @volatile var writeBuffer = new PartitionsBuffer[K, C]( + @volatile var writeBuffer = new PartitionsBuffer( + shuffleId, numPartitions, + aggregator, comparator, conf, - context.taskMemoryManager()) + context.taskMemoryManager(), + shuffleIO, + serializer) private[this] var _elementsRead = 0 - private var _writtenBytes = 0L - def writtenBytes: Long = _writtenBytes - def peakMemoryUsedBytes: Long = writeBuffer.peakSize def insertAll(records: Iterator[Product2[K, V]]): Unit = { @@ -109,19 +116,18 @@ class MapPartitionsWriter[K, V, C]( writeBuffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) } } - // logInfo(context.taskAttemptId() + " insert time: " + (System.nanoTime() - start)/1000000) } def commitAll: Array[Long] = { writeBuffer.flushAll writeBuffer.close - daosWriter.flushAll() - daosWriter.getPartitionLens(numPartitions) + writeBuffer.daosWriter.flushAll() + writeBuffer.daosWriter.getPartitionLens(numPartitions) } def close: Unit = { // serialize rest of records - daosWriter.close + 
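// Worked example for the lowGrantWatermark computed above, with illustrative settings that
// are not taken from this patch: spark.executor.memory=8g, spark.memory.fraction=0.6,
// spark.executor.cores=4, spark.task.cpus=1, spark.shuffle.daos.spill.grant.pct=0.1.
// The 300 MiB term matches the reserved system memory Spark excludes before applying the
// memory fraction.
val exeMemBytes = 8192L * 1024 * 1024
val memFraction = 0.6
val execCores = 4
val cpusPerTask = 1
val grantPct = 0.1
val maxMemPerTask = (exeMemBytes - 300L * 1024 * 1024) * memFraction * cpusPerTask / execCores
val watermark = (maxMemPerTask * grantPct).toLong   // about 118 MiB per task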
writeBuffer.daosWriter.close } protected def addElementsRead(): Unit = { _elementsRead += 1 } @@ -134,62 +140,67 @@ class MapPartitionsWriter[K, V, C]( * @tparam K * @tparam C */ - private[daos] class PartitionsBuffer[K, C]( + private[daos] class PartitionsBuffer( + val shuffleId: Int, numPartitions: Int, + val aggregator: Option[Aggregator[K, V, C]], val keyComparator: Option[Comparator[K]], val conf: SparkConf, - val taskMemManager: TaskMemoryManager) extends MemoryConsumer(taskMemManager) { - private val partBufferThreshold = conf.get(SHUFFLE_DAOS_WRITE_PARTITION_BUFFER_SIZE).toInt * 1024 - private val totalBufferThreshold = conf.get(SHUFFLE_DAOS_WRITE_BUFFER_SIZE).toInt * 1024 * 1024 + val taskMemManager: TaskMemoryManager, + val shuffleIO: DaosShuffleIO, + val serializer: Serializer) extends MemoryConsumer(taskMemManager, taskMemManager.pageSizeBytes(), + MemoryMode.ON_HEAP) { private val totalBufferInitial = conf.get(SHUFFLE_DAOS_WRITE_BUFFER_INITIAL_SIZE).toInt * 1024 * 1024 private val forceWritePct = conf.get(SHUFFLE_DAOS_WRITE_BUFFER_FORCE_PCT) - private val totalWriteValve = totalBufferThreshold * forceWritePct private val partMoveInterval = conf.get(SHUFFLE_DAOS_WRITE_PARTITION_MOVE_INTERVAL) private val totalWriteInterval = conf.get(SHUFFLE_DAOS_WRITE_TOTAL_INTERVAL) - private val totalPartRatio = totalWriteInterval / partMoveInterval private[daos] val sampleStat = new SampleStat + // track spill status + private[daos] var merging = false + private var lowGranted = 0 + + val taskContext = context + val serializerManager = SparkEnv.get.serializerManager + val serInstance = serializer.newInstance() + val writeMetrics = context.taskMetrics().shuffleWriteMetrics + val spillWriteMetrics = new DummyShuffleWriteMetrics + + val needSpill = aggregator.isDefined + val daosWriter = shuffleIO.getDaosWriter( + numPartitions, + shuffleId, + context.taskAttemptId()) + if (needSpill) { + daosWriter.enableSpill() + } + if (log.isDebugEnabled()) { - log.debug("partBufferThreshold: " + partBufferThreshold) - log.debug("totalBufferThreshold: " + totalBufferThreshold) log.debug("totalBufferInitial: " + totalBufferInitial) log.debug("forceWritePct: " + forceWritePct) - log.debug("totalWriteValve: " + totalWriteValve) log.debug("partMoveInterval: " + partMoveInterval) log.debug("totalWriteInterval: " + totalWriteInterval) } - if (totalBufferInitial > totalBufferThreshold) { - throw new IllegalArgumentException("total buffer initial size (" + totalBufferInitial + ") should be no more " + - "than total buffer threshold (" + totalBufferThreshold + ").") - } - - if (totalPartRatio == 0) { - throw new IllegalArgumentException("totalWriteInterval (" + totalWriteInterval + ") should be no less than" + - " partMoveInterval (" + partMoveInterval) - } - private var totalSize = 0L private var memoryLimit = totalBufferInitial * 1L private var largestSize = 0L var peakSize = 0L - private def initialize[T >: Linked[K, C] with SizeAware[K, C]](): - (T, T, Array[SizeAwareMap[K, C]], Array[SizeAwareBuffer[K, C]]) = { + private def initialize[T >: Linked[K, V, C] with SizeAware[K, V, C]](): + (T, T, Array[SizeAwarePartMap], Array[SizeAwarePartBuffer]) = { // create virtual partition head and end, as well as all linked partitions val (partitionMapArray, partitionBufferArray) = if (comparator.isDefined) { - (new Array[SizeAwareMap[K, C]](numPartitions), null) + (new Array[SizeAwarePartMap](numPartitions), null) } else { - (null, new Array[SizeAwareBuffer[K, C]](numPartitions)) + (null, new 
Array[SizeAwarePartBuffer](numPartitions)) } val (head, end) = if (comparator.isDefined) { - val mapHead = new SizeAwareMap[K, C](-1, partBufferThreshold, - totalBufferInitial, taskMemManager, this) - val mapEnd = new SizeAwareMap[K, C](-2, partBufferThreshold, - totalBufferInitial, taskMemManager, this) + val mapHead = new SizeAwarePartMap(-1, this) + val mapEnd = new SizeAwarePartMap(-2, this) (0 until numPartitions).foreach(i => { - val map = new SizeAwareMap[K, C](i, partBufferThreshold, totalBufferInitial, taskMemManager, this) + val map = new SizeAwarePartMap(i, this) partitionMapArray(i) = map if (i > 0) { val prevMap = partitionMapArray(i - 1) @@ -199,12 +210,10 @@ class MapPartitionsWriter[K, V, C]( }) (mapHead, mapEnd) } else { - val bufferHead = new SizeAwareBuffer[K, C](-1, partBufferThreshold, - totalBufferInitial, taskMemManager, this) - val bufferEnd = new SizeAwareBuffer[K, C](-2, partBufferThreshold, - totalBufferInitial, taskMemManager, this) + val bufferHead = new SizeAwarePartBuffer(-1, this) + val bufferEnd = new SizeAwarePartBuffer(-2, this) (0 until numPartitions).foreach(i => { - val buffer = new SizeAwareBuffer[K, C](i, partBufferThreshold, totalBufferInitial, taskMemManager, this) + val buffer = new SizeAwarePartBuffer(i, this) partitionBufferArray(i) = buffer if (i > 0) { val prevBuffer = partitionBufferArray(i - 1) @@ -227,7 +236,7 @@ class MapPartitionsWriter[K, V, C]( private val (head, end, partitionMapArray, partitionBufferArray) = initialize() - private def moveToFirst(node: Linked[K, C] with SizeAware[K, C]): Unit = { + private def moveToFirst(node: Linked[K, V, C] with SizeAware[K, V, C]): Unit = { if (head.next != node) { // remove node from list node.prev.next = node.next @@ -237,12 +246,10 @@ class MapPartitionsWriter[K, V, C]( head.next.prev = node head.next = node node.prev = head - // set largestSize - largestSize = head.next.estimatedSize } } - private def moveToLast(node: Linked[K, C] with SizeAware[K, C]): Unit = { + private def moveToLast(node: Linked[K, V, C] with SizeAware[K, V, C]): Unit = { if (end.prev != node) { // remove node from list node.prev.next = node.next @@ -261,7 +268,7 @@ class MapPartitionsWriter[K, V, C]( if (estSize == 0 || map.numOfRecords % partMoveInterval == 0) { movePartition(estSize, map) } - if (sampleStat.numUpdates % totalWriteInterval == 0) { + if (totalSize > memoryLimit & (sampleStat.numUpdates % totalWriteInterval == 0)) { // check if total buffer exceeds memory limit maybeWriteTotal() } @@ -273,13 +280,13 @@ class MapPartitionsWriter[K, V, C]( if (estSize == 0 || buffer.numOfRecords % partMoveInterval == 0) { movePartition(estSize, buffer) } - if (sampleStat.numUpdates % totalWriteInterval == 0) { + if (totalSize > memoryLimit & (sampleStat.numUpdates % totalWriteInterval == 0)) { // check if total buffer exceeds memory limit maybeWriteTotal() } } - def movePartition[T <: SizeAware[K, C] with Linked[K, C]](estSize: Long, buffer: T): Unit = { + def movePartition[T <: SizeAware[K, V, C] with Linked[K, V, C]](estSize: Long, buffer: T): Unit = { if (estSize > largestSize) { largestSize = estSize moveToFirst(buffer) @@ -288,33 +295,62 @@ class MapPartitionsWriter[K, V, C]( } } - private def writeFromHead: Unit = { + /** + * move range of nodes from start to until nextNotFlushed which is before the end + * @param nextNotFlushed + */ + private def moveFromStartToLast(nextNotFlushed: Linked[K, V, C] with SizeAware[K, V, C]): Unit = { + val start = head.next + val last = nextNotFlushed.prev + if (start != last) { + // 
de-link + head.next = nextNotFlushed + nextNotFlushed.prev = head + // move to last + nextNotFlushed.next = start + start.prev = nextNotFlushed + last.next = end + end.prev = last + } else { + moveToLast(start) + } + } + + /** + * At least one node being written. + * @param size + * @return + */ + private def writeFromHead(size: Long): Long = { var buffer = head.next - var count = 0 - var totalSize = 0L - while (buffer != end && count < totalPartRatio) { - totalSize += buffer.estimatedSize - buffer.writeAndFlush - val emptyBuffer = buffer + var totalWritten = 0L + while ((buffer != end) & totalWritten < size) { + totalWritten += buffer.writeAndFlush buffer = buffer.next - moveToLast(emptyBuffer) - count += 1 } + if (buffer != end) { + largestSize = buffer.estimatedSize + moveFromStartToLast(buffer) + } else { + largestSize = 0L + } + totalWritten } private def maybeWriteTotal(): Unit = { - // write some partition out if total size is bigger than valve - if (totalSize > totalWriteValve) { - writeFromHead - } - if (totalSize > memoryLimit) { - val limit = Math.min(2 * totalSize, totalBufferThreshold) - val memRequest = limit - memoryLimit - val granted = acquireMemory(memRequest) - memoryLimit += granted - if (totalSize >= memoryLimit) { - writeFromHead + val limit = 2 * totalSize + val memRequest = limit - memoryLimit + val granted = acquireMemory(memRequest) + memoryLimit += granted + if (granted < memRequest) { + lowGranted += 1 + if (!spillFirst) { + writeFromHead(memRequest - granted) + } else if (granted < lowGrantWatermark | (lowGranted >= 2)) { + writeFromHead(totalSize - totalBufferInitial) } + } else { + lowGranted = 0 } } @@ -326,121 +362,55 @@ class MapPartitionsWriter[K, V, C]( } def releaseMemory(memory: Long): Unit = { + freeMemory(Math.min(memory, memoryLimit - totalBufferInitial)) memoryLimit -= memory + if (memoryLimit < totalBufferInitial) { + memoryLimit = totalBufferInitial + } } def flushAll: Unit = { val buffer = if (comparator.isDefined) partitionMapArray else partitionBufferArray - buffer.foreach(e => e.writeAndFlush) - } - - def close: Unit = { - val buffer = if (comparator.isDefined) partitionMapArray else partitionBufferArray - buffer.foreach(b => b.close) - } - - def spill(size: Long, trigger: MemoryConsumer): Long = ??? 
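// The grow-or-spill policy of maybeWriteTotal above, restated as one self-contained
// predicate for clarity (the parameter names mirror PartitionsBuffer fields; this is a
// sketch, not code from the class). consecutiveLowGrants counts short grants including
// the current one.
def shouldSpill(
    spillFirst: Boolean,
    granted: Long,
    requested: Long,
    lowGrantWatermark: Long,
    consecutiveLowGrants: Int): Boolean = {
  if (granted >= requested) false                  // fully granted: keep buffering in memory
  else if (!spillFirst) true                       // spill as soon as any grant falls short
  else granted < lowGrantWatermark || consecutiveLowGrants >= 2
}
// e.g. with the roughly 118 MiB watermark computed earlier:
// shouldSpill(spillFirst = true, granted = 64L << 20, requested = 256L << 20,
//             lowGrantWatermark = 118L << 20, consecutiveLowGrants = 1) returns true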
- } - - private[daos] trait SizeAware[K, C] { - this: MemoryConsumer => - - protected var writeCount = 0 - - protected var lastSize = 0L - - protected var _pairsWriter: PartitionOutput = null - - def partitionId: Int - - def writeThreshold: Int - - def estimatedSize: Long - - def totalBufferInitial: Long - - def iterator: Iterator[(K, C)] - - def reset: Unit - - def parent: PartitionsBuffer[K, C] - - def pairsWriter: PartitionOutput - - def updateTotalSize(estSize: Long): Unit = { - val diff = estSize - lastSize - if (diff > 0) { - lastSize = estSize - parent.updateTotalSize(diff) - } - } - - def releaseMemory(memory: Long): Unit = { - freeMemory(memory) - parent.releaseMemory(memory) - } - - private def writeAndFlush(memory: Long): Unit = { - val writer = if (_pairsWriter != null) _pairsWriter else pairsWriter - var count = 0 - iterator.foreach(p => { - writer.write(p._1, p._2) - count += 1 - }) - if (count > 0) { - writer.flush // force write - writeCount += count - lastSize = 0 - parent.updateTotalSize(-memory) - releaseMemory(memory - totalBufferInitial) - reset - } - } - - def writeAndFlush: Unit = { - writeAndFlush(estimatedSize) - } - - def maybeWrite(memory: Long): Boolean = { - if (memory < writeThreshold) { - false - } else { - writeAndFlush(memory) - true - } - } - - def afterUpdate(estSize: Long): Long = { - if (maybeWrite(estSize)) { - 0L + if (!needSpill) { + buffer.foreach(e => { + e.writeAndFlush + e.close + }) } else { - updateTotalSize(estSize) - estSize + // no more spill for existing in-mem data + daosWriter.startMerging() + merging = true + var totalDiskSpilled = 0L + var totalMemSpilled = 0L + buffer.foreach(e => { + totalDiskSpilled += e.merge + totalMemSpilled += e.spillMemSize + }) + context.taskMetrics().incDiskBytesSpilled(totalDiskSpilled) + context.taskMetrics().incMemoryBytesSpilled(totalMemSpilled) } + context.taskMetrics().incPeakExecutionMemory(peakSize) } def close: Unit = { - if (_pairsWriter != null) { - _pairsWriter.close - _pairsWriter = null + val allocated = memoryLimit - totalBufferInitial + // partitions already closed in flushAll + if (allocated > 0) { + freeMemory(allocated) } } - } - private[daos] trait Linked[K, C] { - this: SizeAware[K, C] => - - var prev: Linked[K, C] with SizeAware[K, C] = null - var next: Linked[K, C] with SizeAware[K, C] = null + override def spill(size: Long, trigger: MemoryConsumer): Long = { + 0L + } } - private class SizeAwareMap[K, C]( + private class SizeAwarePartMap( val partitionId: Int, - val writeThreshold: Int, - val totalBufferInitial: Long, - taskMemoryManager: TaskMemoryManager, - val parent: PartitionsBuffer[K, C]) extends MemoryConsumer(taskMemoryManager) - with Linked[K, C] with SizeAware[K, C] { + val parent: PartitionsBuffer) extends + { + val pairsDefaultWriter = new PartitionOutput[K, V, C](partitionId, parent, parent.writeMetrics) + } with Linked[K, V, C] with SizeAware[K, V, C] { private var map = new SizeSamplerAppendOnlyMap[K, C](parent.sampleStat) private var _estSize: Long = _ @@ -453,7 +423,7 @@ class MapPartitionsWriter[K, V, C]( afterUpdate(_estSize) } - def numOfRecords: Int = map.numOfRecords + override def numOfRecords: Int = map.numOfRecords def reset: Unit = { map = new SizeSamplerAppendOnlyMap[K, C](parent.sampleStat) @@ -463,29 +433,14 @@ class MapPartitionsWriter[K, V, C]( def iterator(): Iterator[(K, C)] = { map.destructiveSortedIterator(parent.keyComparator.get) } - - def spill(size: Long, trigger: MemoryConsumer): Long = { - val curSize = _estSize - writeAndFlush - curSize - } - - 
def pairsWriter: PartitionOutput = { - if (_pairsWriter == null) { - _pairsWriter = new PartitionOutput(shuffleId, context.taskAttemptId(), partitionId, serializerManager, - serInstance, daosWriter, writeMetrics) - } - _pairsWriter - } } - private class SizeAwareBuffer[K, C]( + private class SizeAwarePartBuffer( val partitionId: Int, - val writeThreshold: Int, - val totalBufferInitial: Long, - taskMemoryManager: TaskMemoryManager, - val parent: PartitionsBuffer[K, C]) extends MemoryConsumer(taskMemoryManager) - with Linked[K, C] with SizeAware[K, C] { + val parent: PartitionsBuffer) extends + { + val pairsDefaultWriter = new PartitionOutput[K, V, C](partitionId, parent, parent.writeMetrics) + } with Linked[K, V, C] with SizeAware[K, V, C] { private var buffer = new SizeSamplerPairBuffer[K, C](parent.sampleStat) private var _estSize: Long = _ @@ -498,7 +453,7 @@ class MapPartitionsWriter[K, V, C]( afterUpdate(_estSize) } - def numOfRecords: Int = buffer.numOfRecords + override def numOfRecords: Int = buffer.numOfRecords def reset: Unit = { buffer = new SizeSamplerPairBuffer[K, C](parent.sampleStat) @@ -508,19 +463,138 @@ class MapPartitionsWriter[K, V, C]( def iterator(): Iterator[(K, C)] = { buffer.iterator() } + } +} + +object MapPartitionsWriter { + + private[daos] trait SizeAware[K, V, C] { + + protected var totalSpillMem = 0L + + protected var lastSize = 0L + + val pairsDefaultWriter: PartitionOutput[K, V, C] + + val partitionId: Int + + val parent: MapPartitionsWriter[K, V, C]#PartitionsBuffer + + val shuffleId: Int = parent.shuffleId + + val daosWriter: DaosWriter = parent.daosWriter + + val aggregator: Option[Aggregator[K, V, C]] = parent.aggregator + + val keyComparator: Option[Comparator[K]] = parent.keyComparator + + val shuffleIO: DaosShuffleIO = parent.shuffleIO + + val serializer = parent.serializer + + def numOfRecords: Int + + def estimatedSize: Long + + def iterator: Iterator[(K, C)] + + def spillMemSize: Long = totalSpillMem + + def reset: Unit + + def updateTotalSize(estSize: Long): Unit = { + val diff = estSize - lastSize + if (diff > 0) { + lastSize = estSize + parent.updateTotalSize(diff) + } + } + + def releaseMemory(memory: Long): Unit = { + parent.releaseMemory(memory) + parent.updateTotalSize(-memory) + } + + def pairsWriter: PartitionOutput[K, V, C] = { + if ((!parent.needSpill) | parent.merging) pairsDefaultWriter + else new PartitionOutput[K, V, C](partitionId, parent, parent.spillWriteMetrics) + } - def spill(size: Long, trigger: MemoryConsumer): Long = { - val curSize = _estSize - writeAndFlush - curSize + def postFlush(memory: Long): Unit = { + lastSize = 0 + releaseMemory(memory) + reset + } + + /** + * supposed to be non-empty buffer. + * + * @param memory + * @return + */ + private def writeAndFlush(memory: Long): Long = { + val pw = pairsWriter + iterator.foreach(p => { + pw.writeAutoFlush(p._1, p._2) + }) + if (pw == pairsDefaultWriter) { + pw.flush // force write + } else { + pw.close // writer for spill, so flush all and close + totalSpillMem += memory + } + postFlush(memory) + memory + } + + def writeAndFlush: Long = { + if (numOfRecords > 0) { + writeAndFlush(estimatedSize) + } else { + 0L + } } - def pairsWriter: PartitionOutput = { - if (_pairsWriter == null) { - _pairsWriter = new PartitionOutput(shuffleId, context.taskAttemptId(), partitionId, serializerManager, - serInstance, daosWriter, writeMetrics) + def merge: Long = { + if (daosWriter.isSpilled(partitionId)) { // partition actually spilled ? 
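// Spilled partitions are re-read from DAOS and merge-sorted with whatever is still held
// in memory by PartitionMerger (added further below in this patch), combining values of
// equal keys through the aggregator's mergeCombiners; a partition that never spilled
// skips the merge, is flushed directly, and contributes nothing to the spill metrics.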
+ val merger = new PartitionMerger[K, V, C](this, shuffleIO, serializer) + val spilledSize = merger.mergeAndOutput + postFlush(estimatedSize) + close + spilledSize + } else { + writeAndFlush + close + 0L } - _pairsWriter } + + def afterUpdate(estSize: Long): Long = { + updateTotalSize(estSize) + estSize + } + + def close: Unit = { + pairsDefaultWriter.close + } + } + + private[daos] trait Linked[K, V, C] { + this: SizeAware[K, V, C] => + + var prev: Linked[K, V, C] with SizeAware[K, V, C] = null + var next: Linked[K, V, C] with SizeAware[K, V, C] = null + } + + private[daos] class DummyShuffleWriteMetrics extends ShuffleWriteMetricsReporter { + override private[spark] def incBytesWritten(v: Long): Unit = {} + + override private[spark] def incRecordsWritten(v: Long): Unit = {} + + override private[spark] def incWriteTime(v: Long): Unit = {} + + override private[spark] def decBytesWritten(v: Long): Unit = {} + + override private[spark] def decRecordsWritten(v: Long): Unit = {} } } diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/PartitionMerger.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/PartitionMerger.scala new file mode 100644 index 00000000..9d2588bd --- /dev/null +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/PartitionMerger.scala @@ -0,0 +1,177 @@ +/* + * (C) Copyright 2018-2021 Intel Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * GOVERNMENT LICENSE RIGHTS-OPEN SOURCE SOFTWARE + * The Government's rights to use, modify, reproduce, release, perform, display, + * or disclose this software are subject to the terms of the Apache License as + * provided in Contract No. B609815. + * Any reproduction of computer software, computer software documentation, or + * portions thereof marked with this legend must also reproduce the markings. 
+ */ + +package org.apache.spark.shuffle.daos + +import java.util.{Comparator, UUID} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkEnv +import org.apache.spark.executor.TempShuffleReadMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.daos.DaosWriter.SpillInfo +import org.apache.spark.shuffle.daos.MapPartitionsWriter.SizeAware +import org.apache.spark.storage.{BlockId, TempShuffleBlockId} + +class PartitionMerger[K, V, C] ( + val part: SizeAware[K, V, C], + val io: DaosShuffleIO, + val serializer: Serializer) extends Logging { + + def mergeAndOutput: Long = { + val pw = part.pairsWriter + val infoList = part.daosWriter.getSpillInfo(part.partitionId) + var totalSpilled = 0L + val iterators = infoList.asScala.map(info => { + totalSpilled += info.getSize + new SpillIterator[K, C](info) + }) ++ Seq(part.iterator) + val it = mergeWithAggregation(iterators, part.aggregator.get.mergeCombiners, part.keyComparator.get) + while (it.hasNext) { + val item = it.next() + pw.writeAutoFlush(item._1, item._2) + } + pw.close + totalSpilled + } + + /** + * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each + * iterator is sorted by key with a given comparator. If the comparator is not a total ordering + * (e.g. when we sort objects by hash code and different keys may compare as equal although + * they're not), we still merge them by doing equality tests for all keys that compare as equal. + */ + private def mergeWithAggregation( + iterators: Seq[Iterator[Product2[K, C]]], + mergeCombiners: (C, C) => C, + comparator: Comparator[K]): Iterator[Product2[K, C]] = { + // We only have a partial ordering, e.g. comparing the keys by hash code, which means that + // multiple distinct keys might be treated as equal by the ordering. To deal with this, we + // need to read all keys considered equal by the ordering at once and compare them. + val it = new Iterator[Iterator[Product2[K, C]]] { + val sorted = mergeSort(iterators, comparator).buffered + + // Buffers reused across elements to decrease memory allocation + val keys = new ArrayBuffer[K] + val combiners = new ArrayBuffer[C] + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Iterator[Product2[K, C]] = { + if (!hasNext) { + throw new NoSuchElementException + } + keys.clear() + combiners.clear() + val firstPair = sorted.next() + keys += firstPair._1 + combiners += firstPair._2 + val key = firstPair._1 + while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) { + val pair = sorted.next() + var i = 0 + var foundKey = false + while (i < keys.size && !foundKey) { + if (keys(i) == pair._1) { + combiners(i) = mergeCombiners(combiners(i), pair._2) + foundKey = true + } + i += 1 + } + if (!foundKey) { + keys += pair._1 + combiners += pair._2 + } + } + + // Note that we return an iterator of elements since we could've had many keys marked + // equal by the partial order; we flatten this below to get a flat iterator of (K, C). 
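// A concrete case of the partial ordering handled here: with a hash-based key comparator,
// the strings "Aa" and "BB" compare as equal (both hash to 2112), so they remain separate
// entries in keys/combiners and are emitted as distinct pairs rather than being combined.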
+ keys.iterator.zip(combiners.iterator) + } + } + it.flatten + } + + private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K]) + : Iterator[Product2[K, C]] = { + val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) + type Iter = BufferedIterator[Product2[K, C]] + // Use the reverse order (compare(y,x)) because PriorityQueue dequeues the max + val heap = new mutable.PriorityQueue[Iter]()( + (x: Iter, y: Iter) => comparator.compare(y.head._1, x.head._1)) + heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true + new Iterator[Product2[K, C]] { + override def hasNext: Boolean = heap.nonEmpty + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val firstBuf = heap.dequeue() + val firstPair = firstBuf.next() + if (firstBuf.hasNext) { + heap.enqueue(firstBuf) + } + firstPair + } + } + } + + class SpillIterator[K, C] ( + val info: SpillInfo) extends Iterator[Product2[K, C]] { + + import PartitionMerger._ + + private val reader = io.getDaosReader(part.shuffleId) + private val map = new java.util.LinkedHashMap[(String, Integer), (java.lang.Long, BlockId)](1) + map.put((info.getMapId, Integer.valueOf(info.getReduceId)), + (info.getSize, dummyBlockId)) + private val serializerManager = SparkEnv.get.serializerManager + private val daosStream = new DaosShuffleInputStream(reader, map, 1 * 1024 * 1024, + 1 * 1024 * 1024, dummyReadMetrics) + private val wrappedStream = serializerManager.wrapStream(dummyBlockId, daosStream) + private val deStream = serializer.newInstance().deserializeStream(wrappedStream) + private val it = deStream.asKeyValueIterator.asInstanceOf[Iterator[(K, C)]] + + override def hasNext: Boolean = { + val ret = it.hasNext + if (!ret) { + deStream.close() + } + ret + } + + override def next(): Product2[K, C] = { + it.next() + } + } +} + +object PartitionMerger { + val dummyBlockId = TempShuffleBlockId(UUID.randomUUID()) + val dummyReadMetrics = new TempShuffleReadMetrics +} diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/PartitionOutput.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/PartitionOutput.scala index 4c37ca51..de943851 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/PartitionOutput.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/PartitionOutput.scala @@ -25,7 +25,7 @@ package org.apache.spark.shuffle.daos import java.io.OutputStream -import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.serializer.SerializationStream import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.storage.{ShuffleBlockId, TimeTrackingOutputStream} import org.apache.spark.util.Utils @@ -33,23 +33,19 @@ import org.apache.spark.util.Utils /** * Output for each partition. 
* - * @param shuffleId - * @param mapId * @param partitionId - * @param serializerManager - * @param serializerInstance - * @param daosWriter + * @param parent * @param writeMetrics */ -class PartitionOutput( - shuffleId: Int, - mapId: Long, +class PartitionOutput[K, V, C]( partitionId: Int, - serializerManager: SerializerManager, - serializerInstance: SerializerInstance, - daosWriter: DaosWriter, + parent: MapPartitionsWriter[K, V, C]#PartitionsBuffer, writeMetrics: ShuffleWriteMetricsReporter) { + private val mapId = parent.taskContext.taskAttemptId() + private val serializerManager = parent.serializerManager + private val serializerInstance = parent.serInstance + private var ds: DaosShuffleOutputStream = null private var ts: TimeTrackingOutputStream = null @@ -64,11 +60,15 @@ class PartitionOutput( private var lastWrittenBytes = 0L + private val flushRecords = parent.conf.get(SHUFFLE_DAOS_WRITE_FLUSH_RECORDS) + def open: Unit = { - ds = new DaosShuffleOutputStream(partitionId, daosWriter) + ds = new DaosShuffleOutputStream(partitionId, parent.daosWriter) + parent.daosWriter.incrementSeq(partitionId) ts = new TimeTrackingOutputStream(writeMetrics, ds) - bs = serializerManager.wrapStream(ShuffleBlockId(shuffleId, mapId, partitionId), ts) + bs = serializerManager.wrapStream(ShuffleBlockId(parent.shuffleId, mapId, partitionId), ts) objOut = serializerInstance.serializeStream(bs) + opened = true } @@ -78,6 +78,7 @@ class PartitionOutput( } objOut.writeKey(key) objOut.writeValue(value) + // update metrics numRecordsWritten += 1 writeMetrics.incRecordsWritten(1) @@ -86,6 +87,13 @@ class PartitionOutput( } } + def writeAutoFlush(key: Any, value: Any): Unit = { + write(key, value) + if (numRecordsWritten % flushRecords == 0) { + flush + } + } + private def updateWrittenBytes: Unit = { val writtenBytes = ds.getWrittenBytes writeMetrics.incBytesWritten(writtenBytes - lastWrittenBytes) @@ -95,7 +103,7 @@ class PartitionOutput( def flush: Unit = { if (opened) { objOut.flush() - daosWriter.flush(partitionId) + parent.daosWriter.flush(partitionId) updateWrittenBytes } } diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/ShufflePartitionIterator.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/ShufflePartitionIterator.scala index 5a767aae..f1b3a4ab 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/ShufflePartitionIterator.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/ShufflePartitionIterator.scala @@ -35,41 +35,44 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, S import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} class ShufflePartitionIterator( - context: TaskContext, - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + val context: TaskContext, + val blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter, - daosReader: DaosReader, - doBatchFetch: Boolean) extends Iterator[(BlockId, InputStream)] with Logging { + daosReader: DaosReader) extends Iterator[(BlockId, InputStream)] with Logging { - private var lastMapReduce: (java.lang.Long, Integer) = _ - private var lastMRBlock: (java.lang.Long, BlockId, BlockManagerId) = _ + protected var lastMapReduce: (String, Integer) = _ + protected var lastMRBlock: 
(java.lang.Long, BlockId) = _ - private[daos] var inputStream: DaosShuffleInputStream = null + // (mapid, reduceid) -> (length, BlockId) + protected val mapReduceIdMap = new util.LinkedHashMap[(String, Integer), (java.lang.Long, BlockId)] + + protected var mapReduceIt: util.Iterator[util.Map.Entry[(String, Integer), (java.lang.Long, BlockId)]] = _ - // (mapid, reduceid) -> (length, BlockId, BlockManagerId) - private val mapReduceIdMap = new util.LinkedHashMap[(java.lang.Long, Integer), - (java.lang.Long, BlockId, BlockManagerId)] + protected val onCompleteCallback = new ShufflePartitionCompletionListener(this) - private var mapReduceIt: util.Iterator[util.Map.Entry[(java.lang.Long, Integer), - (java.lang.Long, BlockId, BlockManagerId)]] = _ + private[daos] var inputStream: DaosShuffleInputStream = null - private val onCompleteCallback = new ShufflePartitionCompletionListener(this) + val dummyBlkId = BlockManagerId("-1", "dummy-host", 1024) initialize def initialize: Unit = { context.addTaskCompletionListener(onCompleteCallback) startReading + + inputStream = new DaosShuffleInputStream(daosReader, mapReduceIdMap, + maxBytesInFlight, maxReqSizeShuffleToMem, shuffleMetrics) + mapReduceIt = mapReduceIdMap.entrySet().iterator() } - private def getMapReduceId(shuffleBlockId: ShuffleBlockId): (java.lang.Long, Integer) = { + private def getMapReduceId(shuffleBlockId: ShuffleBlockId): (String, Integer) = { val name = shuffleBlockId.name.split("_") - (java.lang.Long.valueOf(name(2)), Integer.valueOf(name(3))) + (name(2), Integer.valueOf(name(3))) } private def startReading: Unit = { @@ -79,22 +82,35 @@ class ShufflePartitionIterator( if (mapReduceIdMap.containsKey(mapReduceId._1)) { throw new IllegalStateException("duplicate map id: " + mapReduceId._1) } - mapReduceIdMap.put((mapReduceId._1, mapReduceId._2), (t3._2, t3._1, t2._1)) + mapReduceIdMap.put((mapReduceId._1, mapReduceId._2), (t3._2, t3._1)) }) }) if (log.isDebugEnabled) { log.debug(s"total mapreduceId: ${mapReduceIdMap.size()}, they are, ") - mapReduceIdMap.forEach((key, value) => logDebug(key.toString() + " = " + value.toString)) + mapReduceIdMap.forEach((key, value) => logDebug(context.taskAttemptId() + ": " + + key.toString() + " = " + value.toString)) } + } - inputStream = new DaosShuffleInputStream(daosReader, mapReduceIdMap, - maxBytesInFlight, maxReqSizeShuffleToMem, shuffleMetrics) - mapReduceIt = mapReduceIdMap.entrySet().iterator() + def throwFetchFailedException( + blockId: BlockId, + address: BlockManagerId, + e: Throwable): Nothing = { + blockId match { + // -1 as mapIndex to avoid removing map output + case ShuffleBlockId(shufId, mapId, reduceId) => + throw new FetchFailedException(address, shufId, mapId, -1, reduceId, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw new FetchFailedException(address, shuffleId, mapId, -1, startReduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", e) + } } override def hasNext: Boolean = { - (!inputStream.isCompleted()) && mapReduceIt.hasNext + (!inputStream.isCompleted()) & mapReduceIt.hasNext } override def next(): (BlockId, InputStream) = { @@ -117,38 +133,17 @@ class ShufflePartitionIterator( } catch { case e: IOException => logError(s"got an corrupted block ${inputStream.getCurBlockId} originated from " + - s"${inputStream.getCurOriginAddress}.", e) + s"${dummyBlkId}.", e) throw e } finally { if (input == null) { inputStream.close(false) } } - (lastBlockId, new 
BufferReleasingInputStream(lastMapReduce, lastMRBlock, input, this, + (lastBlockId, new BufferReleasingInputStream(lastMRBlock, input, this, detectCorrupt && streamCompressedOrEncryptd)) } - def throwFetchFailedException( - blockId: BlockId, - mapIndex: Int, - address: BlockManagerId, - e: Throwable): Nothing = { - blockId match { - case ShuffleBlockId(shufId, mapId, reduceId) => - throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e) - case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => - throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, e) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block", e) - } - } - - def toCompletionIterator: Iterator[(BlockId, InputStream)] = { - CompletionIterator[(BlockId, InputStream), this.type](this, - onCompleteCallback.onTaskCompletion(context)) - } - def cleanup: Unit = { if (inputStream != null) { inputStream.close(false) @@ -156,6 +151,10 @@ class ShufflePartitionIterator( } } + def toCompletionIterator: Iterator[(BlockId, InputStream)] = { + CompletionIterator[(BlockId, InputStream), this.type](this, + onCompleteCallback.onTaskCompletion(context)) + } } /** @@ -163,13 +162,12 @@ class ShufflePartitionIterator( * also detects stream corruption if streamCompressedOrEncrypted is true */ private class BufferReleasingInputStream( - // This is visible for testing - private val mapreduce: (java.lang.Long, Integer), - private val mrblock: (java.lang.Long, BlockId, BlockManagerId), + private val mrblock: (java.lang.Long, BlockId), private val delegate: InputStream, private val iterator: ShufflePartitionIterator, private val detectCorruption: Boolean) extends InputStream { + private[this] var closed = false override def read(): Int = { @@ -178,8 +176,8 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(mrblock._2, mapreduce._1.toInt, - mrblock._3, e) + iterator.throwFetchFailedException(mrblock._2, + iterator.dummyBlkId, e) } } @@ -200,8 +198,8 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(mrblock._2, mapreduce._1.toInt, - mrblock._3, e) + iterator.throwFetchFailedException(mrblock._2, + iterator.dummyBlkId, e) } } @@ -213,8 +211,8 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(mrblock._2, mapreduce._1.toInt, - mrblock._3, e) + iterator.throwFetchFailedException(mrblock._2, + iterator.dummyBlkId, e) } } @@ -224,15 +222,15 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(mrblock._2, mapreduce._1.toInt, - mrblock._3, e) + iterator.throwFetchFailedException(mrblock._2, + iterator.dummyBlkId, e) } } override def reset(): Unit = delegate.reset() } -private class ShufflePartitionCompletionListener(var data: ShufflePartitionIterator) +private[daos] class ShufflePartitionCompletionListener(var data: ShufflePartitionIterator) extends TaskCompletionListener { override def onTaskCompletion(context: TaskContext): Unit = { diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSampler.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSampler.scala index 
40e8ffff..8372a5e6 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSampler.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSampler.scala @@ -43,24 +43,34 @@ private[spark] trait SizeSampler { /** Total number of insertions and updates into the map since the last resetSamples(). */ private var numUpdates: Long = _ - private var bytesPerUpdate: Double = _ - private var stat: SampleStat = _ + private var buffer: Boolean = _ + protected var curSize = 0 - protected def setSampleStat(stat: SampleStat): Unit = { + protected def setSampleStat(stat: SampleStat, buffer: Boolean): Unit = { this.stat = stat + this.buffer = buffer } /** * Reset samples collected so far. * This should be called after the collection undergoes a dramatic change in size. + * After growing size, it's only called by buffer, not map. */ protected def resetSamples(): Unit = { numUpdates = 1 samples.clear() - takeSample() + var inced = false + if (stat.numUpdates == 0) { + stat.numUpdates = 1 + inced = true + } + takeSample + if (inced) { + stat.numUpdates = 0 + } } protected def afterUpdate(): Unit = { @@ -68,34 +78,38 @@ private[spark] trait SizeSampler { curSize += 1 stat.incNumUpdates if (stat.needSample) { - takeSample() + takeSample } } + /** + * @return number of records consumed + */ def numOfRecords: Int = curSize + /** + * @return number of elements in map or buffer + */ + def size: Int + /** * Take a new sample of the current collection's size. */ - protected def takeSample(): Unit = { + protected def takeSample: Unit = { samples.enqueue(Sample(SizeEstimator.estimate(this), numUpdates)) // Only use the last two samples to extrapolate if (samples.size > 2) { samples.dequeue() } - var updateDelta = 0L val bytesDelta = samples.toList.reverse match { case latest :: previous :: _ => - updateDelta = latest.numUpdates - previous.numUpdates - (latest.size - previous.size).toDouble / updateDelta + val updateDelta = latest.numUpdates - previous.numUpdates + if (buffer) (latest.size - previous.size).toDouble / updateDelta + else latest.size / latest.numUpdates // possible case for map-combine // If fewer than 2 samples, assume no change case _ => 0 } - if (updateDelta == 0) { - return - } - bytesPerUpdate = math.max(0, bytesDelta) - stat.updateStat(bytesPerUpdate, updateDelta) + stat.updateStat(bytesDelta) } /** @@ -103,8 +117,8 @@ private[spark] trait SizeSampler { */ def estimateSize(): Long = { assert(samples.nonEmpty) - val bpu = if (bytesPerUpdate == 0) stat.bytesPerUpdate else bytesPerUpdate - val extrapolatedDelta = bpu * (numUpdates - samples.last.numUpdates) + val nbr = numUpdates - samples.last.numUpdates + val extrapolatedDelta = stat.bytesPerUpdate * nbr (samples.last.size + extrapolatedDelta).toLong } } @@ -121,10 +135,8 @@ private[spark] class SampleStat { private[daos] var nextSampleNum: Long = 1 private[daos] var bytesPerUpdate: Double = 0 - def updateStat(partBpu: Double, partUpdateDelta: Long): Unit = { - bytesPerUpdate = ((numUpdates - partUpdateDelta) * bytesPerUpdate + - partUpdateDelta * partBpu - ) / numUpdates + def updateStat(partBpu: Double): Unit = { + bytesPerUpdate = partBpu lastNumUpdates = numUpdates nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong } diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerAppendOnlyMap.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerAppendOnlyMap.scala index cb78c5bc..10a69a33 100644 --- 
a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerAppendOnlyMap.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerAppendOnlyMap.scala @@ -29,7 +29,7 @@ private[spark] class SizeSamplerAppendOnlyMap[K, V](val stat: SampleStat) extends AppendOnlyMap[K, V] with SizeSampler { - setSampleStat(stat) + setSampleStat(stat, false) resetSamples() override def update(key: K, value: V): Unit = { diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerPairBuffer.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerPairBuffer.scala index fab21e58..19f98378 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerPairBuffer.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/SizeSamplerPairBuffer.scala @@ -45,7 +45,7 @@ class SizeSamplerPairBuffer[K, V](val stat: SampleStat, initialCapacity: Int = 6 private var capacity = initialCapacity private var data = new Array[AnyRef](2 * initialCapacity) - setSampleStat(stat) + setSampleStat(stat, true) resetSamples() /** Add an element into the buffer */ @@ -90,6 +90,10 @@ class SizeSamplerPairBuffer[K, V](val stat: SampleStat, initialCapacity: Int = 6 pair } } + + override def size: Int = { + numOfRecords + } } private object SizeSamplerPairBuffer { diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala index 45a804f0..361f15e2 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala @@ -32,55 +32,37 @@ package object daos { val SHUFFLE_DAOS_POOL_UUID = ConfigBuilder("spark.shuffle.daos.pool.uuid") - .version("3.0.0") + .version("3.1.1") .stringConf .createWithDefault(null) val SHUFFLE_DAOS_CONTAINER_UUID = ConfigBuilder("spark.shuffle.daos.container.uuid") - .version("3.0.0") + .version("3.1.1") .stringConf .createWithDefault(null) val SHUFFLE_DAOS_REMOVE_SHUFFLE_DATA = ConfigBuilder("spark.shuffle.remove.shuffle.data") .doc("remove shuffle data from DAOS after shuffle completed. 
Default is true") - .version("3.0.0") + .version("3.1.1") .booleanConf .createWithDefault(true) - val SHUFFLE_DAOS_WRITE_PARTITION_BUFFER_SIZE = - ConfigBuilder("spark.shuffle.daos.write.partition.buffer") - .doc("size of the in-memory buffer for each map partition output, in KiB") - .version("3.0.0") - .bytesConf(ByteUnit.KiB) - .checkValue(v => v > 0, - s"The map partition buffer size must be positive.") - .createWithDefaultString("2048k") - - val SHUFFLE_DAOS_WRITE_BUFFER_SIZE = - ConfigBuilder("spark.shuffle.daos.write.buffer") - .doc("total size of in-memory buffers of each map's all partitions, in MiB") - .version("3.0.0") - .bytesConf(ByteUnit.MiB) - .checkValue(v => v > 50, - s"The total buffer size must be bigger than 50m.") - .createWithDefaultString("800m") - val SHUFFLE_DAOS_WRITE_BUFFER_INITIAL_SIZE = ConfigBuilder("spark.shuffle.daos.write.buffer.initial") .doc("initial size of total in-memory buffer for each map output, in MiB") - .version("3.0.0") + .version("3.1.1") .bytesConf(ByteUnit.MiB) - .checkValue(v => v > 10, - s"The initial total buffer size must be bigger than 10m.") - .createWithDefaultString("80m") + .checkValue(v => v > 0, + s"The initial total buffer size must be bigger than 0.") + .createWithDefaultString("5m") val SHUFFLE_DAOS_WRITE_BUFFER_FORCE_PCT = ConfigBuilder("spark.shuffle.daos.write.buffer.percentage") .doc("percentage of spark.shuffle.daos.buffer. Force write some buffer data out when size is bigger than " + "spark.shuffle.daos.buffer * (this percentage)") - .version("3.0.0") + .version("3.1.1") .doubleConf .checkValue(v => v >= 0.5 && v <= 0.9, s"The percentage must be no less than 0.5 and less than or equal to 0.9") @@ -90,7 +72,7 @@ package object daos { ConfigBuilder("spark.shuffle.daos.write.minimum") .doc("minimum size when write to DAOS, in KiB. A warning will be generated when size is less than this value" + " and spark.shuffle.daos.write.warn.small is set to true") - .version("3.0.0") + .version("3.1.1") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0, s"The DAOS write minimum size must be positive") @@ -100,32 +82,41 @@ package object daos { ConfigBuilder("spark.shuffle.daos.write.warn.small") .doc("log warning message when the size of written data is smaller than spark.shuffle.daos.write.minimum." + " Default is false") - .version("3.0.0") + .version("3.1.1") .booleanConf .createWithDefault(false) val SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.daos.write.buffer.single") .doc("size of single buffer for holding data to be written to DAOS") - .version("3.0.0") - .bytesConf(ByteUnit.MiB) - .checkValue(v => v >= 1, - s"The single DAOS write buffer must be at least 1m") - .createWithDefaultString("2m") + .version("3.1.1") + .bytesConf(ByteUnit.KiB) + .checkValue(v => v >= 1 & v <= 10240, + s"The single DAOS write buffer must be at least 1k") + .createWithDefaultString("256k") + + val SHUFFLE_DAOS_WRITE_FLUSH_RECORDS = + ConfigBuilder("spark.shuffle.daos.write.flush.records") + .doc("per how many number of records to flush data in buffer to DAOS") + .version("3.1.1") + .intConf + .checkValue(v => v >= 100, + s"number of records to flush should be more than 100") + .createWithDefault(1000) val SHUFFLE_DAOS_READ_MINIMUM_SIZE = ConfigBuilder("spark.shuffle.daos.read.minimum") .doc("minimum size when read from DAOS, in KiB. 
") - .version("3.0.0") + .version("3.1.1") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0, s"The DAOS read minimum size must be positive") - .createWithDefaultString("2048k") + .createWithDefaultString("128k") val SHUFFLE_DAOS_READ_MAX_BYTES_IN_FLIGHT = ConfigBuilder("spark.shuffle.daos.read.maxbytes.inflight") .doc("maximum size of requested data when read from DAOS, in KiB. ") - .version("3.0.0") + .version("3.1.1") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0, s"The DAOS read max bytes in flight must be positive") @@ -134,7 +125,7 @@ package object daos { val SHUFFLE_DAOS_WRITE_MAX_BYTES_IN_FLIGHT = ConfigBuilder("spark.shuffle.daos.write.maxbytes.inflight") .doc("maximum size of requested data when write to DAOS, in KiB. ") - .version("3.0.0") + .version("3.1.1") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0, s"The DAOS write max bytes in flight must be positive") @@ -143,7 +134,7 @@ package object daos { val SHUFFLE_DAOS_IO_ASYNC = ConfigBuilder("spark.shuffle.daos.io.async") .doc("perform shuffle IO asynchronously. Default is true") - .version("3.0.0") + .version("3.1.1") .booleanConf .createWithDefault(true) @@ -151,7 +142,7 @@ package object daos { ConfigBuilder("spark.shuffle.daos.read.threads") .doc("number of threads for each executor to read shuffle data concurrently. -1 means use number of executor " + "cores. sync IO only.") - .version("3.0.0") + .version("3.1.1") .intConf .createWithDefault(1) @@ -159,23 +150,23 @@ package object daos { ConfigBuilder("spark.shuffle.daos.write.threads") .doc("number of threads for each executor to write shuffle data concurrently. -1 means use number of executor " + "cores. sync IO only.") - .version("3.0.0") + .version("3.1.1") .intConf .createWithDefault(1) val SHUFFLE_DAOS_ASYNC_WRITE_BATCH_SIZE = ConfigBuilder("spark.shuffle.daos.async.write.batch") .doc("number of async write before flush") - .version("3.0.0") + .version("3.1.1") .intConf .checkValue(v => v > 0, s"async write batch size must be positive") - .createWithDefault(1) + .createWithDefault(30) val SHUFFLE_DAOS_READ_BATCH_SIZE = ConfigBuilder("spark.shuffle.daos.read.batch") .doc("number of read tasks to submit at most at each time. sync IO only.") - .version("3.0.0") + .version("3.1.1") .intConf .checkValue(v => v > 0, s"read batch size must be positive") @@ -184,45 +175,44 @@ package object daos { val SHUFFLE_DAOS_WRITE_SUBMITTED_LIMIT = ConfigBuilder("spark.shuffle.daos.write.submitted.limit") .doc("limit of number of write tasks to submit. 
sync IO only.") - .version("3.0.0") + .version("3.1.1") .intConf .checkValue(v => v > 0, s"limit of submitted task must be positive") .createWithDefault(20) + val SHUFFLE_DAOS_WRITE_ASYNC_DESC_CACHES = + ConfigBuilder("spark.shuffle.daos.write.async.desc.caches") + .doc("number of cached I/O description objects for async write.") + .version("3.1.1") + .intConf + .checkValue(v => v >= 0, + s"number of cached I/O description objects must be no less than 0") + .createWithDefault(20) + val SHUFFLE_DAOS_READ_WAIT_MS = ConfigBuilder("spark.shuffle.daos.read.wait.ms") .doc("number of milliseconds to wait data being read before timed out") - .version("3.0.0") + .version("3.1.1") .intConf .checkValue(v => v > 0, s"wait data time must be positive") - .createWithDefault(5000) + .createWithDefault(60000) val SHUFFLE_DAOS_WRITE_WAIT_MS = ConfigBuilder("spark.shuffle.daos.write.wait.ms") .doc("number of milliseconds to wait data being written before timed out") - .version("3.0.0") + .version("3.1.1") .intConf .checkValue(v => v > 0, s"wait data time must be positive") - .createWithDefault(5000) - - val SHUFFLE_DAOS_READ_WAIT_DATA_TIMEOUT_TIMES = - ConfigBuilder("spark.shuffle.daos.read.wait.timeout.times") - .doc("number of wait timeout (spark.shuffle.daos.read.waitdata.ms) after which shuffle read task reads data " + - "by itself instead of dedicated read thread. sync IO only.") - .version("3.0.0") - .intConf - .checkValue(v => v > 0, - s"wait data timeout times must be positive") - .createWithDefault(5) + .createWithDefault(60000) val SHUFFLE_DAOS_WRITE_WAIT_DATA_TIMEOUT_TIMES = ConfigBuilder("spark.shuffle.daos.write.wait.timeout.times") .doc("number of wait timeout (spark.shuffle.daos.write.waitdata.ms) after which shuffle write task fails." + "sync IO only.") - .version("3.0.0") + .version("3.1.1") .intConf .checkValue(v => v > 0, s"wait data timeout times must be positive") @@ -231,21 +221,21 @@ package object daos { val SHUFFLE_DAOS_READ_FROM_OTHER_THREAD = ConfigBuilder("spark.shuffle.daos.read.from.other.threads") .doc("whether read shuffled data from other threads or not. true by default. sync IO only.") - .version("3.0.0") + .version("3.1.1") .booleanConf .createWithDefault(true) val SHUFFLE_DAOS_WRITE_IN_OTHER_THREAD = ConfigBuilder("spark.shuffle.daos.write.in.other.threads") .doc("whether write shuffled data in other threads or not. true by default. sync IO only.") - .version("3.0.0") + .version("3.1.1") .booleanConf .createWithDefault(true) val SHUFFLE_DAOS_WRITE_PARTITION_MOVE_INTERVAL = ConfigBuilder("spark.shuffle.daos.write.partition.move.interval") .doc("move partition at every this interval (number of records). 1000 records by default.") - .version("3.0.0") + .version("3.1.1") .intConf .checkValue(v => v >= 10, "partition move interval should be at least 10.") .createWithDefault(1000) @@ -255,8 +245,39 @@ package object daos { .doc("check total size of partitions and write some partitions at every this interval (number of records)." + " This value should be no less than spark.shuffle.daos.write.partition.move.interval." 
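// A minimal sketch of setting the knobs introduced or retuned in this file from an
// application; only keys that appear in this patch are used and the values shown are
// illustrative, not recommendations.
import org.apache.spark.SparkConf

val daosShuffleConf = new SparkConf()
  .set("spark.shuffle.daos.spill.first", "true")          // buffer until memory is clearly short
  .set("spark.shuffle.daos.spill.grant.pct", "0.1")       // spill watermark as a fraction of task heap
  .set("spark.shuffle.daos.write.buffer.initial", "5m")   // initial in-memory buffer per map output
  .set("spark.shuffle.daos.write.flush.records", "1000")  // flush partition output every N records
  .set("spark.shuffle.daos.object.hint", "DAOS_OCH_SHD_MAX")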
+ " 10000 records by default.") - .version("3.0.0") + .version("3.1.1") .intConf - .checkValue(v => v >= 100, "total interval should be bigger than 100.") - .createWithDefault(10000) + .checkValue(v => v > 0, "total interval should be bigger than 0.") + .createWithDefault(32) + + val SHUFFLE_DAOS_SPILL_FIRST = + ConfigBuilder("spark.shuffle.daos.spill.first") + .doc("When it's true (default), the shuffle manager will try to not spill until granted memory is less than " + + "task heap memory (\"(executor mem - 300) * spark.memory.fraction * cpusPerCore / executor cores\") * " + + "spark.shuffle.daos.spill.grant.pct. The shuffle manager will also spill if there are equal or more than two" + + " consecutive lowly granted memory (granted memory < requested memory). When it's false, the shuffle manager " + + "will spill once there is lowly granted memory.") + .version("3.1.1") + .booleanConf + .createWithDefault(true) + + val SHUFFLE_DAOS_SPILL_GRANT_PCT = + ConfigBuilder("spark.shuffle.daos.spill.grant.pct") + .doc("percentage of task heap memory (\"(executor mem - 300) * spark.memory.fraction * cpusPerCore / executor" + + " cores\"). It takes effect only if spark.shuffle.daos.spill.first is true. When granted memory from" + + " TaskMemoryManager is less than task heap memory * this percentage, spill data to DAOS. Default is 0.1. It " + + "should be less than 0.5.") + .version("3.1.1") + .doubleConf + .checkValue(v => v > 0 & v < 0.5, "spill grant percentage should be greater than 0 and no more" + + " than 0.5 .") + .createWithDefault(0.1) + + val SHUFFLE_DAOS_OBJECT_HINT = + ConfigBuilder("spark.shuffle.daos.object.hint") + .doc("hint of DAOS object class. It's about data redundancy and sharding in DAOS. Check " + + "io.daos.DaosObjClassHint for all available hints.") + .version("3.1.1") + .stringConf + .createWithDefault("DAOS_OCH_SHD_MAX") } diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosReaderAsyncTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosReaderAsyncTest.java index 5ccf029e..54ce778a 100644 --- a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosReaderAsyncTest.java +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosReaderAsyncTest.java @@ -31,7 +31,6 @@ import org.apache.spark.executor.TempShuffleReadMetrics; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManagerId; import org.apache.spark.storage.ShuffleBlockId; import org.junit.Assert; import org.junit.Before; @@ -46,7 +45,6 @@ import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor; import org.powermock.modules.junit4.PowerMockRunner; import scala.Tuple2; -import scala.Tuple3; import java.util.LinkedHashMap; import java.util.List; @@ -86,9 +84,9 @@ public void testEmptyData() throws Exception { @Test public void testOneEntry() throws Exception { - LinkedHashMap, Tuple3> partSizeMap; + LinkedHashMap, Tuple2> partSizeMap; partSizeMap = new LinkedHashMap<>(); - long mapId = 12345; + String mapId = "12345"; int reduceId = 6789; long len = 1024; int shuffleId = 1000; @@ -110,9 +108,9 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { list.add(desc); return Integer.valueOf(1); } - }).when(eq).pollCompleted(Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); - BlockId blockId = new ShuffleBlockId(shuffleId, mapId, reduceId); - partSizeMap.put(new Tuple2<>(mapId, reduceId), new Tuple3<>(len, blockId, null)); + 
}).when(eq).pollCompleted(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); + BlockId blockId = new ShuffleBlockId(shuffleId, Long.valueOf(mapId), reduceId); + partSizeMap.put(new Tuple2<>(mapId, reduceId), new Tuple2<>(len, blockId)); ShuffleReadMetricsReporter metrics = new TempShuffleReadMetrics(); reader.prepare(partSizeMap, 2 * 1024 * 1024, 2 * 1024 * 1024, metrics); @@ -120,30 +118,30 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { Assert.assertTrue(realBuf == buf); Assert.assertEquals(len, ((TempShuffleReadMetrics) metrics).remoteBytesRead()); - Mockito.verify(eq).pollCompleted(Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); + Mockito.verify(eq).pollCompleted(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); Mockito.verify(object).createAsyncDataDescForFetch(String.valueOf(reduceId), eqHandle); } @Test public void testTwoEntries() throws Exception { - LinkedHashMap, Tuple3> partSizeMap; + LinkedHashMap, Tuple2> partSizeMap; partSizeMap = new LinkedHashMap<>(); - long mapIds[] = new long[] {12345, 12346}; + String[] mapIds = new String[] {"12345", "12346"}; int reduceId = 6789; - long lens[] = new long[] {2 * 1024 * 1024, 1023}; + long[] lens = new long[] {2 * 1024 * 1024, 1023}; int shuffleId = 1000; long eqHandle = 1111L; - IOSimpleDDAsync descs[] = new IOSimpleDDAsync[] {Mockito.mock(IOSimpleDDAsync.class), + IOSimpleDDAsync[] descs = new IOSimpleDDAsync[] {Mockito.mock(IOSimpleDDAsync.class), Mockito.mock(IOSimpleDDAsync.class)}; - IOSimpleDDAsync.AsyncEntry entries[] = new IOSimpleDDAsync.AsyncEntry[] { + IOSimpleDDAsync.AsyncEntry[] entries = new IOSimpleDDAsync.AsyncEntry[] { Mockito.mock(IOSimpleDDAsync.AsyncEntry.class), Mockito.mock(IOSimpleDDAsync.AsyncEntry.class) }; - ByteBuf bufs[] = new ByteBuf[] {Mockito.mock(ByteBuf.class), Mockito.mock(ByteBuf.class)}; - boolean readAlready[] = new boolean[] {false, false}; + ByteBuf[] bufs = new ByteBuf[] {Mockito.mock(ByteBuf.class), Mockito.mock(ByteBuf.class)}; + boolean[] readAlready = new boolean[] {false, false}; for (int i = 0; i < 2; i++) { Mockito.when(entries[i].getFetchedData()).thenReturn(bufs[i]); - Mockito.when(entries[i].getKey()).thenReturn(String.valueOf(mapIds[i])); + Mockito.when(entries[i].getKey()).thenReturn(mapIds[i]); final int index = i; Mockito.doAnswer(new Answer() { @Override @@ -157,12 +155,12 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { }).when(bufs[i]).readableBytes(); Mockito.when(descs[i].getEntry(0)).thenReturn(entries[i]); Mockito.when(descs[i].isSucceeded()).thenReturn(true); - BlockId blockId = new ShuffleBlockId(shuffleId, mapIds[i], reduceId); - partSizeMap.put(new Tuple2<>(mapIds[i], reduceId), new Tuple3<>(lens[i], blockId, null)); + BlockId blockId = new ShuffleBlockId(shuffleId, Long.valueOf(mapIds[i]), reduceId); + partSizeMap.put(new Tuple2<>(mapIds[i], reduceId), new Tuple2<>(lens[i], blockId)); } Mockito.when(eq.getEqWrapperHdl()).thenReturn(eqHandle); - int times[] = new int[] {0}; + int[] times = new int[] {0}; Mockito.doAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { @@ -174,9 +172,9 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { times[0]++; return Integer.valueOf(1); } - }).when(eq).pollCompleted(Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); + }).when(eq).pollCompleted(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); - 
int times2[] = new int[] {0}; + int[] times2 = new int[] {0}; Mockito.doAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { @@ -201,7 +199,7 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable { } Mockito.verify(eq, Mockito.times(2)) - .pollCompleted(Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); + .pollCompleted(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt(), Mockito.anyLong()); Mockito.verify(object, Mockito.times(2)) .createAsyncDataDescForFetch(String.valueOf(reduceId), eqHandle); } diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java index 78664809..2ffeb83a 100644 --- a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java @@ -33,7 +33,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mockito; -import org.mockito.internal.stubbing.answers.DoesNothing; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; @@ -76,11 +75,13 @@ public void testSingleObjectInstanceOpen() throws Exception { DaosObjectId id = PowerMockito.mock(DaosObjectId.class); PowerMockito.whenNew(DaosObjectId.class).withArguments(appId, Long.valueOf(shuffleId)).thenReturn(id); - Mockito.doNothing().when(id).encode(); + Mockito.doNothing().when(id).encode(Mockito.anyLong(), Mockito.any(), Mockito.any(), Mockito.any(), + Mockito.eq(0)); Mockito.when(id.isEncoded()).thenReturn(true); DaosObject daosObject = PowerMockito.mock(DaosObject.class); DaosObjClient client = PowerMockito.mock(DaosObjClient.class); Mockito.when(client.getObject(id)).thenReturn(daosObject); + Mockito.when(daosObject.getOid()).thenReturn(id); AtomicBoolean open = new AtomicBoolean(false); Mockito.when(daosObject.isOpen()).then(invocationOnMock -> @@ -147,8 +148,9 @@ public void testRemoveShuffle() throws Exception { Mockito.doNothing().when(object).punch(); Mockito.doNothing().when(object).close(); objectMap.put(appId + "" + shuffleId, object); + Whitebox.setInternalState(io, "objectMap", objectMap); io.removeShuffle(shuffleId); - Mockito.verify(object); + Mockito.verify(object).punch(); } @Test @@ -164,7 +166,7 @@ public void testKeepShuffledData() throws Exception { Mockito.doThrow(new IllegalStateException("shuffled data should be kept")).when(object).punch(); Mockito.doNothing().when(object).close(); objectMap.put(appId + "" + shuffleId, object); + Whitebox.setInternalState(io, "objectMap", objectMap); io.removeShuffle(shuffleId); - Mockito.verify(object); } } diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java index e61cb9a9..56870c72 100644 --- a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java @@ -31,7 +31,6 @@ import org.apache.spark.TaskContext; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.BlockManagerId; import org.apache.spark.storage.ShuffleBlockId; import org.junit.AfterClass; import 
org.junit.Assert; @@ -46,9 +45,7 @@ import org.powermock.modules.junit4.PowerMockRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.Option; import scala.Tuple2; -import scala.Tuple3; import java.lang.reflect.Constructor; import java.lang.reflect.Method; @@ -156,7 +153,7 @@ public void readFromOtherThreadCancelMultipleTimes(Map ma int expectedFetchTimes = 32; AtomicInteger fetchTimes = new AtomicInteger(0); boolean[] succeeded = new boolean[] {true}; - Method method = IODataDesc.class.getDeclaredMethod("succeed"); + Method method = IODataDescSync.class.getDeclaredMethod("parseFetchResult"); method.setAccessible(true); CountDownLatch latch = new CountDownLatch(expectedFetchTimes); @@ -220,7 +217,7 @@ private void testReadFromOtherThreadCancelOnce(int pos, int desiredOffset, int a AtomicInteger fetchTimes = new AtomicInteger(0); boolean[] succeeded = new boolean[] {true}; AtomicInteger wait = new AtomicInteger(0); - Method method = IODataDesc.class.getDeclaredMethod("succeed"); + Method method = IODataDescSync.class.getDeclaredMethod("parseFetchResult"); method.setAccessible(true); CountDownLatch latch = new CountDownLatch(expectedFetchTimes); @@ -258,14 +255,14 @@ public void testReadSmallMapFromOtherThread() throws Exception { int expectedFetchTimes = 32; AtomicInteger fetchTimes = new AtomicInteger(0); boolean[] succeeded = new boolean[] {true}; - Method method = IODataDesc.class.getDeclaredMethod("succeed"); + Method method = IODataDescSync.class.getDeclaredMethod("parseFetchResult"); method.setAccessible(true); CountDownLatch latch = new CountDownLatch(expectedFetchTimes); Answer answer = (invocationOnMock -> { fetchTimes.getAndIncrement(); - IODataDesc desc = invocationOnMock.getArgument(0); + IODataDescSync desc = invocationOnMock.getArgument(0); desc.encode(); method.invoke(desc); @@ -297,6 +294,7 @@ private void read(int maps, Answer answer, CountDownLatch latch, AtomicInteger fetchTimes, boolean[] succeeded) throws Exception { UserGroupInformation.setLoginUser(UserGroupInformation.createRemoteUser("test")); SparkConf testConf = new SparkConf(false); + testConf.set(package$.MODULE$.SHUFFLE_DAOS_READ_WAIT_MS(), 50); long minSize = 10; testConf.set(package$.MODULE$.SHUFFLE_DAOS_READ_MINIMUM_SIZE(), minSize); SparkContext context = new SparkContext("local", "test", testConf); @@ -317,14 +315,13 @@ private void read(int maps, Answer answer, new DaosReaderSync.ReadThreadFactory()); DaosReaderSync daosReader = new DaosReaderSync(daosObject, new DaosReader.ReaderConfig(testConf), executors.nextExecutor()); - LinkedHashMap, Tuple3> partSizeMap = new LinkedHashMap<>(); + LinkedHashMap, Tuple2> partSizeMap = new LinkedHashMap<>(); int shuffleId = 10; int reduceId = 1; int size = (int)(minSize + 5) * 1024; for (int i = 0; i < maps; i++) { - partSizeMap.put(new Tuple2<>(Long.valueOf(i), 10), new Tuple3<>(Long.valueOf(size), - new ShuffleBlockId(shuffleId, i, reduceId), - BlockManagerId.apply("1", "localhost", 2, Option.empty()))); + partSizeMap.put(new Tuple2<>(String.valueOf(i), 10), new Tuple2<>(Long.valueOf(size), + new ShuffleBlockId(shuffleId, i, reduceId))); } DaosShuffleInputStream is = new DaosShuffleInputStream(daosReader, partSizeMap, 2 * minSize * 1024, @@ -333,12 +330,12 @@ private void read(int maps, Answer answer, // verify cancelled task and continuing submission for (int i = 0; i < maps; i++) { byte[] bytes = new byte[size]; - is.read(bytes); + int readBytes = is.read(bytes); for (int j = 0; j < 255; j++) { try { Assert.assertEquals((byte) j, 
bytes[j]); } catch (Throwable e) { - LOG.error("error at map " + i + ", loc: " + j); + LOG.error("error at map " + i + ", loc: " + j + ", bytes length: " + readBytes); throw e; } } diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterAsyncTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterAsyncTest.java new file mode 100644 index 00000000..6c704696 --- /dev/null +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterAsyncTest.java @@ -0,0 +1,128 @@ +package org.apache.spark.shuffle.daos; + +import io.daos.obj.IODescUpdAsync; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.List; + +public class DaosWriterAsyncTest { + + @Test + public void testDescCacheNotFull() { + DaosWriterAsync.AsyncDescCache cache = new DaosWriterAsync.AsyncDescCache(10) { + public IODescUpdAsync newObject() { + IODescUpdAsync desc = Mockito.mock(IODescUpdAsync.class); + Mockito.when(desc.isReusable()).thenReturn(true); + return desc; + } + }; + List list = new ArrayList<>(); + try { + for (int i = 0; i < 5; i++) { + IODescUpdAsync desc = cache.get(); + Assert.assertTrue(desc.isReusable()); + Assert.assertEquals(i + 1, cache.getIdx()); + list.add(desc); + } + // test reuse + IODescUpdAsync desc = list.remove(0); + cache.put(desc); + Assert.assertEquals(4, cache.getIdx()); + Assert.assertEquals(desc, cache.get()); + cache.put(desc); + for (IODescUpdAsync d : list) { + cache.put(d); + } + Assert.assertEquals(0, cache.getIdx()); + } finally { + cache.release(); + list.forEach(d -> d.release()); + } + } + + @Test + public void testDescCacheFull() { + DaosWriterAsync.AsyncDescCache cache = new DaosWriterAsync.AsyncDescCache(10) { + public IODescUpdAsync newObject() { + IODescUpdAsync desc = Mockito.mock(IODescUpdAsync.class); + Mockito.when(desc.isReusable()).thenReturn(true); + return desc; + } + }; + List list = new ArrayList<>(); + Exception ee = null; + try { + for (int i = 0; i < 11; i++) { + IODescUpdAsync desc = cache.get(); + Assert.assertTrue(desc.isReusable()); + Assert.assertEquals(i + 1, cache.getIdx()); + list.add(desc); + } + } catch (IllegalStateException e) { + ee = e; + } + Assert.assertTrue(ee instanceof IllegalStateException); + Assert.assertTrue(ee.getMessage().contains("cache is full")); + Assert.assertTrue(cache.isFull()); + + try { + // test reuse + IODescUpdAsync desc = list.remove(0); + cache.put(desc); + Assert.assertEquals(9, cache.getIdx()); + Assert.assertEquals(desc, cache.get()); + cache.put(desc); + for (IODescUpdAsync d : list) { + cache.put(d); + } + Assert.assertEquals(0, cache.getIdx()); + desc = cache.get(); + Assert.assertEquals(1, cache.getIdx()); + cache.put(desc); + Assert.assertEquals(0, cache.getIdx()); + } finally { + cache.release(); + list.forEach(d -> d.release()); + } + } + + @Test + public void testDescCachePut() { + DaosWriterAsync.AsyncDescCache cache = new DaosWriterAsync.AsyncDescCache(10) { + public IODescUpdAsync newObject() { + IODescUpdAsync desc = Mockito.mock(IODescUpdAsync.class); + Mockito.when(desc.isReusable()).thenReturn(true); + return desc; + } + }; + List list = new ArrayList<>(); + + for (int i = 0; i < 10; i++) { + IODescUpdAsync desc = cache.get(); + Assert.assertTrue(desc.isReusable()); + Assert.assertEquals(i + 1, cache.getIdx()); + list.add(desc); + } + Assert.assertTrue(cache.isFull()); + + IODescUpdAsync desc = null; + while (!list.isEmpty()) { + desc = list.remove(0); + cache.put(desc); + } + 
Exception ee = null; + try { + cache.put(desc); + } catch (Exception e) { + ee = e; + } finally { + cache.release(); + list.forEach(d -> d.release()); + } + Assert.assertTrue(ee instanceof IllegalStateException); + Assert.assertTrue(ee.getMessage().contains("more than actual")); + } +} diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java index cfdaeb94..160aaca6 100644 --- a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java @@ -112,6 +112,7 @@ public void testGetLensWithPartialEmptyPartitions() throws Exception { writer.flush(idx); expectedLens.put(idx, size); } + writer.flushAll(); long[] lens = writer.getPartitionLens(numPart); Assert.assertEquals(numPart, lens.length); for (int i = 0; i < numPart; i++) { @@ -132,10 +133,10 @@ public void testWriteTaskFailed() throws Exception { DaosObject daosObject = Mockito.spy(objectConstructor.newInstance(client, id)); AtomicInteger counter = new AtomicInteger(0); - Method method = IODataDesc.class.getDeclaredMethod("succeed"); + Method method = IODataDescSync.class.getDeclaredMethod("parseUpdateResult"); method.setAccessible(true); Mockito.doAnswer(invoc -> { - IODataDesc desc = invoc.getArgument(0); + IODataDescSync desc = invoc.getArgument(0); desc.encode(); counter.incrementAndGet(); if (counter.get() == 5) { @@ -152,14 +153,14 @@ public void testWriteTaskFailed() throws Exception { .mapId(1) .config(writeConfig); - BoundThreadExecutors executors = new BoundThreadExecutors("read_executors", 1, + BoundThreadExecutors executors = new BoundThreadExecutors("write_executors", 1, new DaosReaderSync.ReadThreadFactory()); DaosWriterSync writer = new DaosWriterSync(daosObject, param, executors.nextExecutor()); for (int i = 0; i < numPart; i++) { writer.write(i, new byte[100]); writer.flush(i); } - + writer.flushAll(); writer.close(); executors.stop(); diff --git a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala index 28ad614b..61845504 100644 --- a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala +++ b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala @@ -144,7 +144,7 @@ class DaosShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { val mapOutputTracker = mock(classOf[MapOutputTracker]) val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) when(mapOutputTracker.getMapSizesByExecutorId( - shuffleId, reduceId, reduceId + 1)).thenReturn { + shuffleId, reduceId)).thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. 
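[Editor's note, not part of the patch] The new DaosWriterAsyncTest above exercises AsyncDescCache as a bounded pool: get() hands out reusable descriptors and advances an index, put() returns one for reuse, and misuse (getting past capacity, or putting back more than was taken) raises IllegalStateException. As a rough mental model only, not the patch's implementation, here is a minimal generic sketch under those assumptions.

import scala.collection.mutable

// Minimal bounded reuse cache with the behaviour the tests check above.
class BoundedReuseCache[T](capacity: Int)(newObject: () => T) {
  private val pool = mutable.ArrayBuffer.empty[T]
  private var idx = 0                       // number of objects currently handed out

  def isFull: Boolean = idx >= capacity
  def getIdx: Int = idx

  def get(): T = {
    if (isFull) throw new IllegalStateException("cache is full")
    idx += 1
    if (pool.nonEmpty) pool.remove(pool.size - 1) else newObject()
  }

  def put(obj: T): Unit = {
    if (idx <= 0) throw new IllegalStateException("put more than actual borrowed objects")
    idx -= 1
    pool += obj
  }
}

object BoundedReuseCacheDemo extends App {
  val cache = new BoundedReuseCache[AnyRef](2)(() => new AnyRef)
  val a = cache.get(); val b = cache.get()
  assert(cache.isFull)
  cache.put(a)
  assert(cache.getIdx == 1)
}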
val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => @@ -172,7 +172,7 @@ class DaosShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { val taskContext = TaskContext.empty() val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId( - shuffleId, reduceId, reduceId + 1) + shuffleId, reduceId) val (daosReader, shuffleIO, daosObject) = if (singleCall) { @@ -187,11 +187,11 @@ class DaosShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { val shuffleReader = new DaosShuffleReader[Int, Int]( shuffleHandle, blocksByAddress, + false, taskContext, metrics, shuffleIO, - serializerManager, - false) + serializerManager) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) diff --git a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterPerf.scala b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterPerf.scala index 2fc2eaed..184178f5 100644 --- a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterPerf.scala +++ b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterPerf.scala @@ -26,7 +26,6 @@ package org.apache.spark.shuffle.daos import org.mockito.{Mock, Mockito, MockitoAnnotations} import org.mockito.Answers._ import org.mockito.Mockito.{mock, when} -import org.scalatest.Matchers import scala.collection.mutable import scala.util.Random @@ -36,7 +35,7 @@ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.util.Utils -class DaosShuffleWriterPerf extends SparkFunSuite with SharedSparkContext with Matchers { +class DaosShuffleWriterPerf extends SparkFunSuite with SharedSparkContext { @Mock(answer = RETURNS_SMART_NULLS) private var shuffleIO: DaosShuffleIO = _ @@ -46,12 +45,9 @@ class DaosShuffleWriterPerf extends SparkFunSuite with SharedSparkContext with M private var shuffleHandle: BaseShuffleHandle[Int, Array[Byte], Array[Byte]] = _ private val serializer = new JavaSerializer(conf) - private val singleBufSize = conf.get(SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE) * 1024 * 1024 + private val singleBufSize = conf.get(SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE) * 1024 private val minSize = conf.get(SHUFFLE_DAOS_WRITE_MINIMUM_SIZE) * 1024 - conf.set(SHUFFLE_DAOS_WRITE_PARTITION_BUFFER_SIZE, 100L) - conf.set(SHUFFLE_DAOS_WRITE_BUFFER_SIZE, 80L) - override def beforeEach(): Unit = { super.beforeEach() MockitoAnnotations.initMocks(this) diff --git a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterSuite.scala b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterSuite.scala index 6a500aca..b4614b99 100644 --- a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterSuite.scala +++ b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleWriterSuite.scala @@ -27,7 +27,6 @@ import org.mockito.{Mock, Mockito, MockitoAnnotations} import org.mockito.Answers._ import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito.{mock, never, when} -import org.scalatest.Matchers import scala.collection.mutable import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} @@ -36,7 +35,7 @@ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.util.Utils -class DaosShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers { +class 
DaosShuffleWriterSuite extends SparkFunSuite with SharedSparkContext { @Mock(answer = RETURNS_SMART_NULLS) private var shuffleIO: DaosShuffleIO = _ @@ -46,7 +45,7 @@ class DaosShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with private var shuffleHandle: BaseShuffleHandle[Int, Int, Int] = _ private val serializer = new JavaSerializer(conf) - private val singleBufSize = conf.get(SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE) * 1024 * 1024 + private val singleBufSize = conf.get(SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE) * 1024 private val minSize = conf.get(SHUFFLE_DAOS_WRITE_MINIMUM_SIZE) * 1024 override def beforeEach(): Unit = { diff --git a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeSamplerSuite.scala b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeSamplerSuite.scala index 6008ec5b..b2fa80b2 100644 --- a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeSamplerSuite.scala +++ b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeSamplerSuite.scala @@ -31,7 +31,7 @@ class SizeSamplerSuite extends SparkFunSuite { test("test sample AppendOnlyMap by update") { val stat = new SampleStat - var grew = false; + var grew = false val map = new SizeSamplerAppendOnlyMap[Int, Int](stat) { override def growTable(): Unit = { super.growTable() diff --git a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeTest.scala b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeTest.scala index 112fe50c..22441e0a 100644 --- a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeTest.scala +++ b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/SizeTest.scala @@ -41,5 +41,4 @@ class SizeTest { assert(size <= deSize && deSize <= size*1.1) }) } - } diff --git a/shuffle-hadoop/dev/post_results_to_PR.sh b/shuffle-hadoop/dev/post_results_to_PR.sh index 7a2b43a3..6d62b942 100644 --- a/shuffle-hadoop/dev/post_results_to_PR.sh +++ b/shuffle-hadoop/dev/post_results_to_PR.sh @@ -1,5 +1,3 @@ -USERNAME=benchmarker-RemoteShuffle -PASSWORD=$BENCHMARKER_PASSWORD PULL_REQUEST_NUM=$TRAVIS_PULL_REQUEST READ_OR_WRITE=$1 @@ -13,6 +11,8 @@ done echo "$RESULTS" +# setup GITHUB_TOKEN in your environment so that you can access api.github.com without password + message='{"body": "```' message+='\n' message+="$RESULTS" @@ -22,5 +22,5 @@ json_message+='```", "event":"COMMENT"}' echo "$json_message" > benchmark_results.json echo "Sending benchmark requests to PR $PULL_REQUEST_NUM" -curl -XPOST https://${USERNAME}:${PASSWORD}@api.github.com/repos/Intel-bigdata/RemoteShuffle/pulls/${PULL_REQUEST_NUM}/reviews -d @benchmark_results.json +curl -XPOST https://api.github.com/repos/Intel-bigdata/RemoteShuffle/pulls/${PULL_REQUEST_NUM}/reviews -d @benchmark_results.json rm benchmark_results.json diff --git a/shuffle-hadoop/pom.xml b/shuffle-hadoop/pom.xml index 1217f00a..c31749b4 100644 --- a/shuffle-hadoop/pom.xml +++ b/shuffle-hadoop/pom.xml @@ -5,10 +5,10 @@ com.intel.oap remote-shuffle-parent - 1.2.0 + 2.2.0-SNAPSHOT - shuffle-hadoop + remote-shuffle-hadoop OAP Remote Shuffle Based on Hadoop Filesystem jar @@ -31,26 +31,21 @@ org.apache.spark - spark-core_2.12 + spark-core_${scala.binary.version} org.apache.spark - spark-core_2.12 + spark-core_${scala.binary.version} tests test-jar test org.apache.spark - spark-sql_2.12 + spark-sql_${scala.binary.version} ${spark.version} test - - org.apache.hadoop - hadoop-client - 2.7.4 - org.scalatest scalatest_${scala.binary.version} @@ -103,6 +98,7 @@ + 
${project.artifactId}-${project.version}-with-spark-${spark.version} net.alchim31.maven diff --git a/shuffle-hadoop/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala b/shuffle-hadoop/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala index c067bbe9..388ea889 100644 --- a/shuffle-hadoop/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala +++ b/shuffle-hadoop/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala @@ -22,9 +22,8 @@ import java.util.{HashMap => JHashMap, Map => JMap} import scala.collection.JavaConverters._ import scala.concurrent.Future import scala.reflect.ClassTag - import com.codahale.metrics.{Metric, MetricSet} - +import org.apache.spark.internal.Logging import org.apache.spark.{SecurityManager, SparkConf, SparkEnv} import org.apache.spark.network.{BlockDataManager, BlockTransferService, TransportContext} import org.apache.spark.network.buffer.ManagedBuffer @@ -46,7 +45,7 @@ private[spark] class RemoteShuffleTransferService( bindAddress: String, override val hostName: String, _port: Int, - numCores: Int) extends BlockTransferService { + numCores: Int) extends BlockTransferService with Logging { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. private val serializer = new JavaSerializer(conf) diff --git a/shuffle-hadoop/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/shuffle-hadoop/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 62d57655..d18e1081 100644 --- a/shuffle-hadoop/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/shuffle-hadoop/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -28,7 +28,6 @@ import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer} import scala.concurrent.duration._ import scala.util.control.NonFatal - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} @@ -37,7 +36,9 @@ import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.resource.ResourceProfile.{DEFAULT_RESOURCE_PROFILE_ID, EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.shuffle.remote.RemoteShuffleManager import org.apache.spark.storage._ @@ -45,79 +46,79 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ /** - * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of - * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a - * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent - * tasks that can run right away based on the data that's already on the cluster (e.g. map output - * files from previous stages), though it may fail if this data becomes unavailable. 
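[Editor's note, not part of the patch] The DAGScheduler scaladoc being reflowed in this hunk describes stage-oriented scheduling, with narrow operations pipelined into one stage and shuffles ending a stage. Purely as an illustration of that boundary (standard Spark API usage, unrelated to the indentation change itself), a tiny job makes it visible in the RDD lineage.

import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext}

object StageBoundarySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("stage-boundary"))
    try {
      // map()/filter() are narrow and stay in one stage; reduceByKey() introduces
      // a ShuffleDependency, i.e. a stage boundary handled by the DAGScheduler.
      val pipelined = sc.parallelize(1 to 100).map(i => (i % 10, i)).filter(_._2 > 5)
      val shuffled = pipelined.reduceByKey(_ + _)
      println(pipelined.dependencies.map(_.getClass.getSimpleName))                    // narrow deps only
      println(shuffled.dependencies.exists(_.isInstanceOf[ShuffleDependency[_, _, _]])) // true
      println(shuffled.toDebugString)                                                  // lineage shows the boundary
    } finally {
      sc.stop()
    }
  }
}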
- * - * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with - * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks - * in each stage, but operations with shuffle dependencies require multiple stages (one to write a - * set of map output files, and another to read those files after a barrier). In the end, every - * stage will have only shuffle dependencies on other stages, and may compute multiple operations - * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of - * various RDDs - * - * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred - * locations to run each task on, based on the current cache status, and passes these to the - * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being - * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are - * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task - * a small number of times before cancelling the whole stage. - * - * When looking through this code, there are several key concepts: - * - * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. - * For example, when the user calls an action, like count(), a job will be submitted through - * submitJob. Each Job may require the execution of multiple stages to build intermediate data. - * - * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each - * task computes the same function on partitions of the same RDD. Stages are separated at shuffle - * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to - * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that - * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. - * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. - * - * - Tasks are individual units of work, each sent to one machine. - * - * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them - * and likewise remembers which shuffle map stages have already produced output files to avoid - * redoing the map side of a shuffle. - * - * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based - * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. - * - * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, - * to prevent memory leaks in a long-running application. - * - * To recover from failures, the same stage might need to run multiple times, which are called - * "attempts". If the TaskScheduler reports that a task failed because a map output file from a - * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a - * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small - * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost - * stage(s) that compute the missing tasks. As part of this process, we might also have to create - * Stage objects for old (finished) stages where we previously cleaned up the Stage object. 
Since - * tasks from the old attempt of a stage could still be running, care must be taken to map any - * events received in the correct Stage object. - * - * Here's a checklist to use when making or reviewing changes to this class: - * - * - All data structures should be cleared when the jobs involving them end to avoid indefinite - * accumulation of state in long-running programs. - * - * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to - * include the new structure. This will help to catch memory leaks. - */ + * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of + * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a + * minimal schedule to run the job. It then submits stages as TaskSets to an underlying + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. + * + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred + * locations to run each task on, based on the current cache status, and passes these to the + * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being + * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are + * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task + * a small number of times before cancelling the whole stage. + * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. + * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. 
+ * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. + */ private[spark] class DAGScheduler( - private[scheduler] val sc: SparkContext, - private[scheduler] val taskScheduler: TaskScheduler, - listenerBus: LiveListenerBus, - mapOutputTracker: MapOutputTrackerMaster, - blockManagerMaster: BlockManagerMaster, - env: SparkEnv, - clock: Clock = new SystemClock()) + private[scheduler] val sc: SparkContext, + private[scheduler] val taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock()) extends Logging { def this(sc: SparkContext, taskScheduler: TaskScheduler) = { @@ -141,11 +142,11 @@ private[spark] class DAGScheduler( private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new HashMap[Int, Stage] /** - * Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for - * that dependency. Only includes stages that are part of currently running job (when the job(s) - * that require the shuffle stage complete, the mapping will be removed, and the only record of - * the shuffle data will be in the MapOutputTracker). - */ + * Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for + * that dependency. Only includes stages that are part of currently running job (when the job(s) + * that require the shuffle stage complete, the mapping will be removed, and the only record of + * the shuffle data will be in the MapOutputTracker). + */ private[scheduler] val shuffleIdToMapStage = new HashMap[Int, ShuffleMapStage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] @@ -161,21 +162,42 @@ private[spark] class DAGScheduler( private[scheduler] val activeJobs = new HashSet[ActiveJob] /** - * Contains the locations that each RDD's partitions are cached on. 
This map's keys are RDD ids - * and its values are arrays indexed by partition numbers. Each array value is the set of - * locations where that RDD partition is cached. - * - * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). - */ + * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids + * and its values are arrays indexed by partition numbers. Each array value is the set of + * locations where that RDD partition is cached. + * + * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + */ private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] - // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with - // every task. When we detect a node failing, we note the current epoch number and failed - // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results. - // - // TODO: Garbage collect information about failure epochs when we know there are no more - // stray messages to detect. - private val failedEpoch = new HashMap[String, Long] + /** + * Tracks the latest epoch of a fully processed error related to the given executor. (We use + * the MapOutputTracker's epoch number, which is sent with every task.) + * + * When an executor fails, it can affect the results of many tasks, and we have to deal with + * all of them consistently. We don't simply ignore all future results from that executor, + * as the failures may have been transient; but we also don't want to "overreact" to follow- + * on errors we receive. Furthermore, we might receive notification of a task success, after + * we find out the executor has actually failed; we'll assume those successes are, in fact, + * simply delayed notifications and the results have been lost, if the tasks started in the + * same or an earlier epoch. In particular, we use this to control when we tell the + * BlockManagerMaster that the BlockManager has been lost. + */ + private val executorFailureEpoch = new HashMap[String, Long] + + /** + * Tracks the latest epoch of a fully processed error where shuffle files have been lost from + * the given executor. + * + * This is closely related to executorFailureEpoch. They only differ for the executor when + * there is an external shuffle service serving shuffle files and we haven't been notified that + * the entire worker has been lost. In that case, when an executor is lost, we do not update + * the shuffleFileLostEpoch; we wait for a fetch failure. This way, if only the executor + * fails, we do not unregister the shuffle data as it can still be served; but if there is + * a failure in the shuffle service (resulting in fetch failure), we unregister the shuffle + * data only once, even if we get many fetch failures. + */ + private val shuffleFileLostEpoch = new HashMap[String, Long] private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator @@ -186,36 +208,38 @@ private[spark] class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. 
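[Editor's note, not part of the patch] The executorFailureEpoch and shuffleFileLostEpoch comments introduced above describe epoch gating: a failure for an executor is fully processed at most once per epoch, and late results from the same or an earlier epoch are treated as lost. The sketch below is a stripped-down illustration of that gating with invented names; it is not the DAGScheduler code itself.

import scala.collection.mutable

// An event for an executor is acted on only if it is newer than the last
// fully processed epoch for that executor.
class EpochGate {
  private val lastHandledEpoch = mutable.HashMap.empty[String, Long]

  /** Returns true if the failure at `epoch` should be processed (and records it). */
  def shouldHandleFailure(execId: String, epoch: Long): Boolean = {
    val handle = lastHandledEpoch.get(execId).forall(epoch > _)
    if (handle) lastHandledEpoch(execId) = epoch
    handle
  }

  /** Late successes from the same or an earlier epoch are assumed lost. */
  def isStaleResult(execId: String, taskEpoch: Long): Boolean =
    lastHandledEpoch.get(execId).exists(taskEpoch <= _)
}

object EpochGateDemo extends App {
  val gate = new EpochGate
  assert(gate.shouldHandleFailure("exec-1", epoch = 3))   // first failure: handle it
  assert(!gate.shouldHandleFailure("exec-1", epoch = 3))  // duplicate report: ignore
  assert(gate.isStaleResult("exec-1", taskEpoch = 2))     // late success: treat as lost
}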
*/ private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) + private val shouldMergeResourceProfiles = sc.getConf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) + /** - * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, - * this is set default to false, which means, we only unregister the outputs related to the exact - * executor(instead of the host) on a FetchFailure. - */ + * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, + * this is set default to false, which means, we only unregister the outputs related to the exact + * executor(instead of the host) on a FetchFailure. + */ private[scheduler] val unRegisterOutputOnHostOnFetchFailure = sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) /** - * Number of consecutive stage attempts allowed before a stage is aborted. - */ + * Number of consecutive stage attempts allowed before a stage is aborted. + */ private[scheduler] val maxConsecutiveStageAttempts = sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) /** - * Number of max concurrent tasks check failures for each barrier job. - */ + * Number of max concurrent tasks check failures for each barrier job. + */ private[scheduler] val barrierJobIdToNumTasksCheckFailures = new ConcurrentHashMap[Int, Int] /** - * Time in seconds to wait between a max concurrent tasks check failure and the next check. - */ + * Time in seconds to wait between a max concurrent tasks check failure and the next check. + */ private val timeIntervalNumTasksCheck = sc.getConf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) /** - * Max number of max concurrent tasks check failures allowed for a job before fail the job - * submission. - */ + * Max number of max concurrent tasks check failures allowed for a job before fail the job + * submission. + */ private val maxFailureNumTasksCheck = sc.getConf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) @@ -225,47 +249,49 @@ private[spark] class DAGScheduler( private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf) + /** - * Called by the TaskSetManager to report task's starting. - */ + * Called by the TaskSetManager to report task's starting. + */ def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = { eventProcessLoop.post(BeginEvent(task, taskInfo)) } /** - * Called by the TaskSetManager to report that a task has completed - * and results are being fetched remotely. - */ + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ def taskGettingResult(taskInfo: TaskInfo): Unit = { eventProcessLoop.post(GettingResultEvent(taskInfo)) } /** - * Called by the TaskSetManager to report task completions or failures. - */ + * Called by the TaskSetManager to report task completions or failures. 
+ */ def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Seq[AccumulatorV2[_, _]], - metricPeaks: Array[Long], - taskInfo: TaskInfo): Unit = { + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + metricPeaks: Array[Long], + taskInfo: TaskInfo): Unit = { eventProcessLoop.post( CompletionEvent(task, reason, result, accumUpdates, metricPeaks, taskInfo)) } /** - * Update metrics for in-progress tasks and let the master know that the BlockManager is still - * alive. Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. - */ + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ def executorHeartbeatReceived( - execId: String, - // (taskId, stageId, stageAttemptId, accumUpdates) - accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], - blockManagerId: BlockManagerId, - // (stageId, stageAttemptId) -> metrics - executorUpdates: mutable.Map[(Int, Int), ExecutorMetrics]): Boolean = { + execId: String, + // (taskId, stageId, stageAttemptId, accumUpdates) + accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId, + // (stageId, stageAttemptId) -> metrics + executorUpdates: mutable.Map[(Int, Int), ExecutorMetrics]): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, executorUpdates)) blockManagerMaster.driverHeartbeatEndPoint.askSync[Boolean]( @@ -273,41 +299,61 @@ private[spark] class DAGScheduler( } /** - * Called by TaskScheduler implementation when an executor fails. - */ + * Called by TaskScheduler implementation when an executor fails. + */ def executorLost(execId: String, reason: ExecutorLossReason): Unit = { eventProcessLoop.post(ExecutorLost(execId, reason)) } /** - * Called by TaskScheduler implementation when a worker is removed. - */ + * Called by TaskScheduler implementation when a worker is removed. + */ def workerRemoved(workerId: String, host: String, message: String): Unit = { eventProcessLoop.post(WorkerRemoved(workerId, host, message)) } /** - * Called by TaskScheduler implementation when a host is added. - */ + * Called by TaskScheduler implementation when a host is added. + */ def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } /** - * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or - * cancellation of the job itself. - */ + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. + */ def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } /** - * Called by the TaskSetManager when it decides a speculative task is needed. - */ + * Called by the TaskSetManager when it decides a speculative task is needed. + */ def speculativeTaskSubmitted(task: Task[_]): Unit = { eventProcessLoop.post(SpeculativeTaskSubmitted(task)) } + /** + * Called by the TaskSetManager when a taskset becomes unschedulable due to executors being + * excluded because of too many task failures and dynamic allocation is enabled. 
+ */ + def unschedulableTaskSetAdded( + stageId: Int, + stageAttemptId: Int): Unit = { + eventProcessLoop.post(UnschedulableTaskSetAdded(stageId, stageAttemptId)) + } + + /** + * Called by the TaskSetManager when an unschedulable taskset becomes schedulable and dynamic + * allocation is enabled. + */ + def unschedulableTaskSetRemoved( + stageId: Int, + stageAttemptId: Int): Unit = { + eventProcessLoop.post(UnschedulableTaskSetRemoved(stageId, stageAttemptId)) + } + private[scheduler] def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times @@ -332,13 +378,13 @@ private[spark] class DAGScheduler( } /** - * Gets a shuffle map stage if one exists in shuffleIdToMapStage. Otherwise, if the - * shuffle map stage doesn't already exist, this method will create the shuffle map stage in - * addition to any missing ancestor shuffle map stages. - */ + * Gets a shuffle map stage if one exists in shuffleIdToMapStage. Otherwise, if the + * shuffle map stage doesn't already exist, this method will create the shuffle map stage in + * addition to any missing ancestor shuffle map stages. + */ private def getOrCreateShuffleMapStage( - shuffleDep: ShuffleDependency[_, _, _], - firstJobId: Int): ShuffleMapStage = { + shuffleDep: ShuffleDependency[_, _, _], + firstJobId: Int): ShuffleMapStage = { shuffleIdToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage @@ -361,38 +407,41 @@ private[spark] class DAGScheduler( } /** - * Check to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The - * following patterns are not supported: - * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (eg. - * union()/coalesce()/first()/take()/PartitionPruningRDD); - * 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)). - */ + * Check to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The + * following patterns are not supported: + * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (e.g. + * union()/coalesce()/first()/take()/PartitionPruningRDD); + * 2. An RDD that depends on multiple barrier RDDs (e.g. barrierRdd1.zip(barrierRdd2)). + */ private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = { if (rdd.isBarrier() && - !traverseParentRDDsWithinStage(rdd, (r: RDD[_]) => - r.getNumPartitions == numTasksInStage && + !traverseParentRDDsWithinStage(rdd, (r: RDD[_]) => + r.getNumPartitions == numTasksInStage && r.dependencies.count(_.rdd.isBarrier()) <= 1)) { throw new BarrierJobUnsupportedRDDChainException } } /** - * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a - * previously run stage generated the same shuffle data, this function will copy the output - * locations that are still available from the previous shuffle to avoid unnecessarily - * regenerating data. - */ + * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a + * previously run stage generated the same shuffle data, this function will copy the output + * locations that are still available from the previous shuffle to avoid unnecessarily + * regenerating data. 
+ */ def createShuffleMapStage[K, V, C]( - shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = { + shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd + val (shuffleDeps, resourceProfiles) = getShuffleDependenciesAndResourceProfiles(rdd) + val resourceProfile = mergeResourceProfilesForStage(resourceProfiles) checkBarrierStageWithDynamicAllocation(rdd) - checkBarrierStageWithNumSlots(rdd) + checkBarrierStageWithNumSlots(rdd, resourceProfile) checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions) val numTasks = rdd.partitions.length - val parents = getOrCreateParentStages(rdd, jobId) + val parents = getOrCreateParentStages(shuffleDeps, jobId) val id = nextStageId.getAndIncrement() val stage = new ShuffleMapStage( - id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker, + resourceProfile.id) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage @@ -409,16 +458,16 @@ private[spark] class DAGScheduler( } /** - * We don't support run a barrier stage with dynamic resource allocation enabled, it shall lead - * to some confusing behaviors (eg. with dynamic resource allocation enabled, it may happen that - * we acquire some executors (but not enough to launch all the tasks in a barrier stage) and - * later release them due to executor idle time expire, and then acquire again). - * - * We perform the check on job submit and fail fast if running a barrier stage with dynamic - * resource allocation enabled. - * - * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage - */ + * We don't support run a barrier stage with dynamic resource allocation enabled, it shall lead + * to some confusing behaviors (e.g. with dynamic resource allocation enabled, it may happen that + * we acquire some executors (but not enough to launch all the tasks in a barrier stage) and + * later release them due to executor idle time expire, and then acquire again). + * + * We perform the check on job submit and fail fast if running a barrier stage with dynamic + * resource allocation enabled. + * + * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage + */ private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { throw new BarrierJobRunWithDynamicAllocationException @@ -426,53 +475,115 @@ private[spark] class DAGScheduler( } /** - * Check whether the barrier stage requires more slots (to be able to launch all tasks in the - * barrier stage together) than the total number of active slots currently. Fail current check - * if trying to submit a barrier stage that requires more slots than current total number. If - * the check fails consecutively beyond a configured number for a job, then fail current job - * submission. - */ - private def checkBarrierStageWithNumSlots(rdd: RDD[_]): Unit = { - val numPartitions = rdd.getNumPartitions - val maxNumConcurrentTasks = sc.maxNumConcurrentTasks - if (rdd.isBarrier() && numPartitions > maxNumConcurrentTasks) { - throw new BarrierJobSlotsNumberCheckFailed(numPartitions, maxNumConcurrentTasks) + * Check whether the barrier stage requires more slots (to be able to launch all tasks in the + * barrier stage together) than the total number of active slots currently. 
Fail current check + * if trying to submit a barrier stage that requires more slots than current total number. If + * the check fails consecutively beyond a configured number for a job, then fail current job + * submission. + */ + private def checkBarrierStageWithNumSlots(rdd: RDD[_], rp: ResourceProfile): Unit = { + if (rdd.isBarrier()) { + val numPartitions = rdd.getNumPartitions + val maxNumConcurrentTasks = sc.maxNumConcurrentTasks(rp) + if (numPartitions > maxNumConcurrentTasks) { + throw new BarrierJobSlotsNumberCheckFailed(numPartitions, maxNumConcurrentTasks) + } } } + private[scheduler] def mergeResourceProfilesForStage( + stageResourceProfiles: HashSet[ResourceProfile]): ResourceProfile = { + logDebug(s"Merging stage rdd profiles: $stageResourceProfiles") + val resourceProfile = if (stageResourceProfiles.size > 1) { + if (shouldMergeResourceProfiles) { + val startResourceProfile = stageResourceProfiles.head + val mergedProfile = stageResourceProfiles.drop(1) + .foldLeft(startResourceProfile)((a, b) => mergeResourceProfiles(a, b)) + // compared merged profile with existing ones so we don't add it over and over again + // if the user runs the same operation multiple times + val resProfile = sc.resourceProfileManager.getEquivalentProfile(mergedProfile) + resProfile match { + case Some(existingRp) => existingRp + case None => + // this ResourceProfile could be different if it was merged so we have to add it to + // our ResourceProfileManager + sc.resourceProfileManager.addResourceProfile(mergedProfile) + mergedProfile + } + } else { + throw new IllegalArgumentException("Multiple ResourceProfiles specified in the RDDs for " + + "this stage, either resolve the conflicting ResourceProfiles yourself or enable " + + s"${config.RESOURCE_PROFILE_MERGE_CONFLICTS.key} and understand how Spark handles " + + "the merging them.") + } + } else { + if (stageResourceProfiles.size == 1) { + stageResourceProfiles.head + } else { + sc.resourceProfileManager.defaultResourceProfile + } + } + resourceProfile + } + + // This is a basic function to merge resource profiles that takes the max + // value of the profiles. We may want to make this more complex in the future as + // you may want to sum some resources (like memory). + private[scheduler] def mergeResourceProfiles( + r1: ResourceProfile, + r2: ResourceProfile): ResourceProfile = { + val mergedExecKeys = r1.executorResources ++ r2.executorResources + val mergedExecReq = mergedExecKeys.map { case (k, v) => + val larger = r1.executorResources.get(k).map( x => + if (x.amount > v.amount) x else v).getOrElse(v) + k -> larger + } + val mergedTaskKeys = r1.taskResources ++ r2.taskResources + val mergedTaskReq = mergedTaskKeys.map { case (k, v) => + val larger = r1.taskResources.get(k).map( x => + if (x.amount > v.amount) x else v).getOrElse(v) + k -> larger + } + new ResourceProfile(mergedExecReq, mergedTaskReq) + } + /** - * Create a ResultStage associated with the provided jobId. - */ + * Create a ResultStage associated with the provided jobId. 
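[Editor's note, not part of the patch] mergeResourceProfiles above combines two profiles by taking, per resource key, the larger requested amount. Independent of Spark's ResourceProfile classes, the per-key max merge it performs reduces to the following sketch over plain maps (MaxMergeSketch is a hypothetical helper for illustration).

// Per-key "take the max" merge, as done for executor and task resources above.
object MaxMergeSketch {
  def mergeByMax(r1: Map[String, Double], r2: Map[String, Double]): Map[String, Double] =
    (r1.keySet ++ r2.keySet).map { k =>
      k -> math.max(r1.getOrElse(k, 0.0), r2.getOrElse(k, 0.0))
    }.toMap

  def main(args: Array[String]): Unit = {
    val a = Map("cores" -> 2.0, "memory" -> 4096.0)
    val b = Map("cores" -> 4.0, "gpu" -> 1.0)
    println(mergeByMax(a, b)) // Map(cores -> 4.0, memory -> 4096.0, gpu -> 1.0)
  }
}

As the in-code comment above notes, summing some resources (such as memory) instead of taking the max may be preferable in the future; this sketch mirrors only the current max-based behaviour.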
+ */ private def createResultStage( - rdd: RDD[_], - func: (TaskContext, Iterator[_]) => _, - partitions: Array[Int], - jobId: Int, - callSite: CallSite): ResultStage = { + rdd: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + jobId: Int, + callSite: CallSite): ResultStage = { + val (shuffleDeps, resourceProfiles) = getShuffleDependenciesAndResourceProfiles(rdd) + val resourceProfile = mergeResourceProfilesForStage(resourceProfiles) checkBarrierStageWithDynamicAllocation(rdd) - checkBarrierStageWithNumSlots(rdd) + checkBarrierStageWithNumSlots(rdd, resourceProfile) checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) - val parents = getOrCreateParentStages(rdd, jobId) + val parents = getOrCreateParentStages(shuffleDeps, jobId) val id = nextStageId.getAndIncrement() - val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite) + val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, + callSite, resourceProfile.id) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage } /** - * Get or create the list of parent stages for a given RDD. The new Stages will be created with - * the provided firstJobId. - */ - private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { - getShuffleDependencies(rdd).map { shuffleDep => + * Get or create the list of parent stages for the given shuffle dependencies. The new + * Stages will be created with the provided firstJobId. + */ + private def getOrCreateParentStages(shuffleDeps: HashSet[ShuffleDependency[_, _, _]], + firstJobId: Int): List[Stage] = { + shuffleDeps.map { shuffleDep => getOrCreateShuffleMapStage(shuffleDep, firstJobId) }.toList } /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getMissingAncestorShuffleDependencies( - rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = { + rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = { val ancestors = new ListBuffer[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError @@ -483,7 +594,8 @@ private[spark] class DAGScheduler( val toVisit = waitingForVisit.remove(0) if (!visited(toVisit)) { visited += toVisit - getShuffleDependencies(toVisit).foreach { shuffleDep => + val (shuffleDeps, _) = getShuffleDependenciesAndResourceProfiles(toVisit) + shuffleDeps.foreach { shuffleDep => if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) { ancestors.prepend(shuffleDep) waitingForVisit.prepend(shuffleDep.rdd) @@ -495,20 +607,22 @@ private[spark] class DAGScheduler( } /** - * Returns shuffle dependencies that are immediate parents of the given RDD. - * - * This function will not return more distant ancestors. For example, if C has a shuffle - * dependency on B which has a shuffle dependency on A: - * - * A <-- B <-- C - * - * calling this function with rdd C will only return the B <-- C dependency. - * - * This function is scheduler-visible for the purpose of unit testing. - */ - private[scheduler] def getShuffleDependencies( - rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = { + * Returns shuffle dependencies that are immediate parents of the given RDD and the + * ResourceProfiles associated with the RDDs for this stage. + * + * This function will not return more distant ancestors for shuffle dependencies. 
For example, + * if C has a shuffle dependency on B which has a shuffle dependency on A: + * + * A <-- B <-- C + * + * calling this function with rdd C will only return the B <-- C dependency. + * + * This function is scheduler-visible for the purpose of unit testing. + */ + private[scheduler] def getShuffleDependenciesAndResourceProfiles( + rdd: RDD[_]): (HashSet[ShuffleDependency[_, _, _]], HashSet[ResourceProfile]) = { val parents = new HashSet[ShuffleDependency[_, _, _]] + val resourceProfiles = new HashSet[ResourceProfile] val visited = new HashSet[RDD[_]] val waitingForVisit = new ListBuffer[RDD[_]] waitingForVisit += rdd @@ -516,6 +630,7 @@ private[spark] class DAGScheduler( val toVisit = waitingForVisit.remove(0) if (!visited(toVisit)) { visited += toVisit + Option(toVisit.getResourceProfile).foreach(resourceProfiles += _) toVisit.dependencies.foreach { case shuffleDep: ShuffleDependency[_, _, _] => parents += shuffleDep @@ -524,13 +639,13 @@ private[spark] class DAGScheduler( } } } - parents + (parents, resourceProfiles) } /** - * Traverses the given RDD and its ancestors within the same stage and checks whether all of the - * RDDs satisfy a given predicate. - */ + * Traverses the given RDD and its ancestors within the same stage and checks whether all of the + * RDDs satisfy a given predicate. + */ private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = { val visited = new HashSet[RDD[_]] val waitingForVisit = new ListBuffer[RDD[_]] @@ -544,7 +659,7 @@ private[spark] class DAGScheduler( visited += toVisit toVisit.dependencies.foreach { case _: ShuffleDependency[_, _, _] => - // Not within the same stage with current rdd, do nothing. + // Not within the same stage with current rdd, do nothing. case dependency => waitingForVisit.prepend(dependency.rdd) } @@ -586,9 +701,9 @@ private[spark] class DAGScheduler( } /** - * Registers the given jobId among the jobs that need the given stage and - * all of that stage's ancestors. - */ + * Registers the given jobId among the jobs that need the given stage and + * all of that stage's ancestors. + */ private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { @tailrec def updateJobIdStageIdMapsList(stages: List[Stage]): Unit = { @@ -604,11 +719,11 @@ private[spark] class DAGScheduler( } /** - * Removes state for job and any stages that are not needed by any other job. Does not - * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. - * - * @param job The job whose state to cleanup. - */ + * Removes state for job and any stages that are not needed by any other job. Does not + * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. + * + * @param job The job whose state to cleanup. + */ private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { val registeredStages = jobIdToStageIds.get(job.jobId) if (registeredStages.isEmpty || registeredStages.get.isEmpty) { @@ -620,7 +735,7 @@ private[spark] class DAGScheduler( if (!jobSet.contains(job.jobId)) { logError( "Job %d not registered for stage %d even though that stage was registered for the job" - .format(job.jobId, stageId)) + .format(job.jobId, stageId)) } else { def removeStage(stageId: Int): Unit = { // data structures based on Stage @@ -664,28 +779,28 @@ private[spark] class DAGScheduler( } /** - * Submit an action job to the scheduler. 
- * - * @param rdd target RDD to run tasks on - * @param func a function to run on each partition of the RDD - * @param partitions set of partitions to run on; some jobs may not want to compute on all - * partitions of the target RDD, e.g. for operations like first() - * @param callSite where in the user program this job was called - * @param resultHandler callback to pass each result to - * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name - * - * @return a JobWaiter object that can be used to block until the job finishes executing - * or can be used to cancel the job. - * - * @throws IllegalArgumentException when partitions ids are illegal - */ + * Submit an action job to the scheduler. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @return a JobWaiter object that can be used to block until the job finishes executing + * or can be used to cancel the job. + * + * @throws IllegalArgumentException when partitions ids are illegal + */ def submitJob[T, U]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - callSite: CallSite, - resultHandler: (Int, U) => Unit, - properties: Properties): JobWaiter[U] = { + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: CallSite, + resultHandler: (Int, U) => Unit, + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -719,36 +834,36 @@ private[spark] class DAGScheduler( } /** - * Run an action job on the given RDD and pass all the results to the resultHandler function as - * they arrive. - * - * @param rdd target RDD to run tasks on - * @param func a function to run on each partition of the RDD - * @param partitions set of partitions to run on; some jobs may not want to compute on all - * partitions of the target RDD, e.g. for operations like first() - * @param callSite where in the user program this job was called - * @param resultHandler callback to pass each result to - * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name - * - * @note Throws `Exception` when the job fails - */ + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name + * + * @note Throws `Exception` when the job fails + */ def runJob[T, U]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - callSite: CallSite, - resultHandler: (Int, U) => Unit, - properties: Properties): Unit = { + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: CallSite, + resultHandler: (Int, U) => Unit, + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf) waiter.completionFuture.value.get match { case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format - (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) case scala.util.Failure(exception) => logInfo("Job %d failed: %s, took %f s".format - (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. val callerStackTrace = Thread.currentThread().getStackTrace.tail exception.setStackTrace(exception.getStackTrace ++ callerStackTrace) @@ -757,28 +872,29 @@ private[spark] class DAGScheduler( } /** - * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator - * as they arrive. Returns a partial result object from the evaluator. - * - * @param rdd target RDD to run tasks on - * @param func a function to run on each partition of the RDD - * @param evaluator `ApproximateEvaluator` to receive the partial results - * @param callSite where in the user program this job was called - * @param timeout maximum time to wait for the job, in milliseconds - * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name - */ + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator `ApproximateEvaluator` to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name + */ def runApproximateJob[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - callSite: CallSite, - timeout: Long, - properties: Properties): PartialResult[R] = { + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + callSite: CallSite, + timeout: Long, + properties: Properties): PartialResult[R] = { val jobId = nextJobId.getAndIncrement() + val clonedProperties = Utils.cloneProperties(properties) if (rdd.partitions.isEmpty) { // Return immediately if the job is running 0 tasks val time = clock.getTimeMillis() - listenerBus.post(SparkListenerJobStart(jobId, time, Seq[StageInfo](), properties)) + listenerBus.post(SparkListenerJobStart(jobId, time, Seq[StageInfo](), clonedProperties)) listenerBus.post(SparkListenerJobEnd(jobId, time, JobSucceeded)) return new PartialResult(evaluator.currentResult(), true) } @@ -786,27 +902,27 @@ private[spark] class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] eventProcessLoop.post(JobSubmitted( jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener, - Utils.cloneProperties(properties))) + clonedProperties)) listener.awaitResult() // Will throw an exception if the job fails } /** - * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter - * can be used to block until the job finishes executing or can be used to cancel the job. - * This method is used for adaptive query planning, to run map stages and look at statistics - * about their outputs before submitting downstream stages. - * - * @param dependency the ShuffleDependency to run a map stage for - * @param callback function called with the result of the job, which in this case will be a - * single MapOutputStatistics object showing how much data was produced for each partition - * @param callSite where in the user program this job was submitted - * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name - */ + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def submitMapStage[K, V, C]( - dependency: ShuffleDependency[K, V, C], - callback: MapOutputStatistics => Unit, - callSite: CallSite, - properties: Properties): JobWaiter[MapOutputStatistics] = { + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { val rdd = dependency.rdd val jobId = nextJobId.getAndIncrement() @@ -828,24 +944,24 @@ private[spark] class DAGScheduler( } /** - * Cancel a job that is running or waiting in the queue. - */ + * Cancel a job that is running or waiting in the queue. 
+ */ def cancelJob(jobId: Int, reason: Option[String]): Unit = { logInfo("Asked to cancel job " + jobId) eventProcessLoop.post(JobCancelled(jobId, reason)) } /** - * Cancel all jobs in the given job group ID. - */ + * Cancel all jobs in the given job group ID. + */ def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) } /** - * Cancel all jobs that are running or waiting in the queue. - */ + * Cancel all jobs that are running or waiting in the queue. + */ def cancelAllJobs(): Unit = { eventProcessLoop.post(AllJobsCancelled) } @@ -859,25 +975,25 @@ private[spark] class DAGScheduler( } /** - * Cancel all jobs associated with a running or scheduled stage. - */ + * Cancel all jobs associated with a running or scheduled stage. + */ def cancelStage(stageId: Int, reason: Option[String]): Unit = { eventProcessLoop.post(StageCancelled(stageId, reason)) } /** - * Kill a given task. It will be retried. - * - * @return Whether the task was successfully killed. - */ + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { taskScheduler.killTaskAttempt(taskId, interruptThread, reason) } /** - * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since - * the last fetch failure. - */ + * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + * the last fetch failure. + */ private[scheduler] def resubmitFailedStages(): Unit = { if (failedStages.nonEmpty) { // Failed stages may be removed by job cancellation, so failed might be empty even if @@ -893,10 +1009,10 @@ private[spark] class DAGScheduler( } /** - * Check for waiting stages which are now eligible for resubmission. - * Submits stages that depend on the given parent stage. Called when the parent stage completes - * successfully. - */ + * Check for waiting stages which are now eligible for resubmission. + * Submits stages that depend on the given parent stage. Called when the parent stage completes + * successfully. + */ private def submitWaitingChildStages(parent: Stage): Unit = { logTrace(s"Checking if any dependencies of $parent are now runnable") logTrace("running: " + runningStages) @@ -929,14 +1045,14 @@ private[spark] class DAGScheduler( } val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, - Option("part of cancelled job group %s".format(groupId)))) + Option("part of cancelled job group %s".format(groupId)))) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { // Note that there is a chance that this task is launched after the stage is cancelled. // In that case, we wouldn't have the stage anymore in stageIdToStage. 
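// --- Editorial sketch, not part of this patch: the cancellation handlers above sit behind the
// public SparkContext job-group API. A typical caller-side pattern, assuming an active
// SparkContext `sc`; the group name and workload are made up for illustration.
import org.apache.spark.SparkContext

object JobGroupCancelSketch {
  def runCancellable(sc: SparkContext): Unit = {
    // interruptOnCancel = true is what shouldInterruptTaskThread() later reads back from the job
    // properties to decide whether running task threads may be interrupted.
    sc.setJobGroup("nightly-etl", "cancellable ETL job", interruptOnCancel = true)
    try {
      sc.parallelize(1 to 1000000).map(_ * 2).count()
    } finally {
      sc.clearJobGroup()
    }
  }
  // From another thread: sc.cancelJobGroup("nightly-etl"), which posts JobGroupCancelled and is
  // handled by handleJobGroupCancelled above.
}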
val stageAttemptId = - stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } @@ -944,10 +1060,22 @@ private[spark] class DAGScheduler( listenerBus.post(SparkListenerSpeculativeTaskSubmitted(task.stageId, task.stageAttemptId)) } + private[scheduler] def handleUnschedulableTaskSetAdded( + stageId: Int, + stageAttemptId: Int): Unit = { + listenerBus.post(SparkListenerUnschedulableTaskSetAdded(stageId, stageAttemptId)) + } + + private[scheduler] def handleUnschedulableTaskSetRemoved( + stageId: Int, + stageAttemptId: Int): Unit = { + listenerBus.post(SparkListenerUnschedulableTaskSetRemoved(stageId, stageAttemptId)) + } + private[scheduler] def handleTaskSetFailed( - taskSet: TaskSet, - reason: String, - exception: Option[Throwable]): Unit = { + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } } @@ -974,12 +1102,12 @@ private[spark] class DAGScheduler( } private[scheduler] def handleJobSubmitted(jobId: Int, - finalRDD: RDD[_], - func: (TaskContext, Iterator[_]) => _, - partitions: Array[Int], - callSite: CallSite, - listener: JobListener, - properties: Properties): Unit = { + finalRDD: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + callSite: CallSite, + listener: JobListener, + properties: Properties): Unit = { var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a @@ -1035,15 +1163,16 @@ private[spark] class DAGScheduler( val stageIds = jobIdToStageIds(jobId).toArray val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, + Utils.cloneProperties(properties))) submitStage(finalStage) } private[scheduler] def handleMapStageSubmitted(jobId: Int, - dependency: ShuffleDependency[_, _, _], - callSite: CallSite, - listener: JobListener, - properties: Properties): Unit = { + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties): Unit = { // Submitting this map stage might still require the creation of some parent stages, so make // sure that happens. var finalStage: ShuffleMapStage = null @@ -1073,7 +1202,8 @@ private[spark] class DAGScheduler( val stageIds = jobIdToStageIds(jobId).toArray val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, + Utils.cloneProperties(properties))) submitStage(finalStage) // If the whole stage has already finished, tell the listener and remove it @@ -1106,6 +1236,54 @@ private[spark] class DAGScheduler( } } + /** + * `PythonRunner` needs to know what the pyspark memory and cores settings are for the profile + * being run. Pass them in the local properties of the task if it's set for the stage profile. 
+ */ + private def addPySparkConfigsToProperties(stage: Stage, properties: Properties): Unit = { + val rp = sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId) + val pysparkMem = rp.getPySparkMemory + // use the getOption on EXECUTOR_CORES.key instead of using the EXECUTOR_CORES config reader + // because the default for this config isn't correct for standalone mode. Here we want + // to know if it was explicitly set or not. The default profile always has it set to either + // what user specified or default so special case it here. + val execCores = if (rp.id == DEFAULT_RESOURCE_PROFILE_ID) { + sc.conf.getOption(config.EXECUTOR_CORES.key) + } else { + val profCores = rp.getExecutorCores.map(_.toString) + if (profCores.isEmpty) sc.conf.getOption(config.EXECUTOR_CORES.key) else profCores + } + pysparkMem.map(mem => properties.setProperty(PYSPARK_MEMORY_LOCAL_PROPERTY, mem.toString)) + execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores)) + } + + /** + * If push based shuffle is enabled, set the shuffle services to be used for the given + * shuffle map stage for block push/merge. + * + * Even with dynamic resource allocation kicking in and significantly reducing the number + * of available active executors, we would still be able to get sufficient shuffle service + * locations for block push/merge by getting the historical locations of past executors. + */ + private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = { + // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize + // TODO changes we cannot disable shuffle merge for the retry/reuse cases + val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( + stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) + + if (mergerLocs.nonEmpty) { + stage.shuffleDep.setMergerLocs(mergerLocs) + logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" + + s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + + logDebug("List of shuffle push merger locations " + + s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") + } else { + logInfo("No available merger locations." + + s" Push-based shuffle disabled for $stage (${stage.name})") + } + } + /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") @@ -1125,6 +1303,7 @@ private[spark] class DAGScheduler( // Use the scheduling pool, job group, description, etc. from an ActiveJob associated // with this Stage val properties = jobIdToActiveJob(jobId).properties + addPySparkConfigsToProperties(stage, properties) runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -1134,6 +1313,12 @@ private[spark] class DAGScheduler( stage match { case s: ShuffleMapStage => outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) + // Only generate merger location for a given shuffle dependency once. This way, even if + // this stage gets retried, it would still be merging blocks using the same set of + // shuffle services. 
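// --- Editorial sketch, not part of this patch: prepareShuffleServicesForShuffleMapStage above is
// only invoked when pushBasedShuffleEnabled is true. The config keys below follow upstream
// Spark's push-based shuffle work (SPARK-30602) and are an assumption for this fork; the merger
// locations are hosted by the external shuffle service.
import org.apache.spark.SparkConf

object PushBasedShuffleConfSketch {
  def conf(): SparkConf = new SparkConf()
    .set("spark.shuffle.push.enabled", "true")    // map-side block push to remote mergers
    .set("spark.shuffle.service.enabled", "true") // shuffle services provide the merger locations
}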
+ if (pushBasedShuffleEnabled) { + prepareShuffleServicesForShuffleMapStage(s) + } case s: ResultStage => outputCommitCoordinator.stageStart( stage = s.id, maxPartitionId = s.rdd.partitions.length - 1) @@ -1151,7 +1336,8 @@ private[spark] class DAGScheduler( } catch { case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) - listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, + Utils.cloneProperties(properties))) abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return @@ -1165,7 +1351,8 @@ private[spark] class DAGScheduler( if (partitionsToCompute.nonEmpty) { stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } - listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, + Utils.cloneProperties(properties))) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast @@ -1251,7 +1438,8 @@ private[spark] class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties, + stage.resourceProfileId)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run @@ -1260,9 +1448,9 @@ private[spark] class DAGScheduler( stage match { case stage: ShuffleMapStage => logDebug(s"Stage ${stage} is actually done; " + - s"(available: ${stage.isAvailable}," + - s"available outputs: ${stage.numAvailableOutputs}," + - s"partitions: ${stage.numPartitions})") + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") markMapStageJobsAsFinished(stage) case stage : ResultStage => logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") @@ -1272,15 +1460,15 @@ private[spark] class DAGScheduler( } /** - * Merge local values from a task into the corresponding accumulators previously registered - * here on the driver. - * - * Although accumulators themselves are not thread-safe, this method is called only from one - * thread, the one that runs the scheduling loop. This means we only handle one task - * completion event at a time so we don't need to worry about locking the accumulators. - * This still doesn't stop the caller from updating the accumulator outside the scheduler, - * but that's not our problem since there's nothing we can do about that. - */ + * Merge local values from a task into the corresponding accumulators previously registered + * here on the driver. + * + * Although accumulators themselves are not thread-safe, this method is called only from one + * thread, the one that runs the scheduling loop. This means we only handle one task + * completion event at a time so we don't need to worry about locking the accumulators. + * This still doesn't stop the caller from updating the accumulator outside the scheduler, + * but that's not our problem since there's nothing we can do about that. 
+ */ private def updateAccumulators(event: CompletionEvent): Unit = { val task = event.task val stage = stageIdToStage(task.stageId) @@ -1336,9 +1524,9 @@ private[spark] class DAGScheduler( } /** - * Check [[SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL]] in job properties to see if we should - * interrupt running tasks. Returns `false` if the property value is not a boolean value - */ + * Check [[SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL]] in job properties to see if we should + * interrupt running tasks. Returns `false` if the property value is not a boolean value + */ private def shouldInterruptTaskThread(job: ActiveJob): Boolean = { if (job.properties == null) { false @@ -1357,9 +1545,9 @@ private[spark] class DAGScheduler( } /** - * Responds to a task finishing. This is called inside the event loop so it assumes that it can - * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. - */ + * Responds to a task finishing. This is called inside the event loop so it assumes that it can + * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. + */ private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = { val task = event.task val stageId = task.stageId @@ -1372,7 +1560,7 @@ private[spark] class DAGScheduler( event.reason) if (!stageIdToStage.contains(task.stageId)) { - // The stage may have already finished when we get this event -- eg. maybe it was a + // The stage may have already finished when we get this event -- e.g. maybe it was a // speculative task. It is important that we send the TaskEnd event in any case, so listeners // are properly notified and can chose to handle it. For instance, some listeners are // doing their own accounting and if they don't get the task end event they think @@ -1473,7 +1661,8 @@ private[spark] class DAGScheduler( val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { + if (executorFailureEpoch.contains(execId) && + smt.epoch <= executorFailureEpoch(execId)) { logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { // The epoch of the task is acceptable (i.e., the task was launched after the most @@ -1527,7 +1716,7 @@ private[spark] class DAGScheduler( failedStage.failedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest + disallowStageRetryForTest // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is @@ -1674,10 +1863,19 @@ private[spark] class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && - unRegisterOutputOnHostOnFetchFailure) { - // We had a fetch failure with the external shuffle service, so we - // assume all shuffle data on the node is bad. 
+ val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled + val isHostDecommissioned = taskScheduler + .getExecutorDecommissionState(bmAddress.executorId) + .exists(_.workerHost.isDefined) + + // Shuffle output of all executors on host `bmAddress.host` may be lost if: + // - External shuffle service is enabled, so we assume that all shuffle data on node is + // bad. + // - Host is decommissioned, thus all executors on that host will die. + val shuffleOutputOfEntireHostLost = externalShuffleServiceEnabled || + isHostDecommissioned + val hostToUnregisterOutputs = if (shuffleOutputOfEntireHostLost + && unRegisterOutputOnHostOnFetchFailure) { Some(bmAddress.host) } else { // Unregister shuffle data just for one executor (we don't have any @@ -1688,7 +1886,14 @@ private[spark] class DAGScheduler( execId = bmAddress.executorId, fileLost = true, hostToUnregisterOutputs = hostToUnregisterOutputs, - maybeEpoch = Some(task.epoch)) + maybeEpoch = Some(task.epoch), + // shuffleFileLostEpoch is ignored when a host is decommissioned because some + // decommissioned executors on that host might have been removed before this fetch + // failure and might have bumped up the shuffleFileLostEpoch. We ignore that, and + // proceed with unconditional removal of shuffle outputs from all executors on that + // host, including from those that we still haven't confirmed as lost due to heartbeat + // delays. + ignoreShuffleFileLostEpoch = isHostDecommissioned) } } @@ -1716,7 +1921,9 @@ private[spark] class DAGScheduler( // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) " + "failed." - taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason) + val job = jobIdToActiveJob.get(failedStage.firstJobId) + val shouldInterrupt = job.exists(j => shouldInterruptTaskThread(j)) + taskScheduler.killAllTaskAttempts(stageId, shouldInterrupt, reason) } catch { case e: UnsupportedOperationException => // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. @@ -1731,8 +1938,8 @@ private[spark] class DAGScheduler( // TODO Refactor the failure handling logic to combine similar code with that of // FetchFailed. val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { @@ -1778,17 +1985,17 @@ private[spark] class DAGScheduler( handleResubmittedFailure(task, stage) case _: TaskCommitDenied => - // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits + // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case _: ExceptionFailure | _: TaskKilled => - // Nothing left to do, already handled above for accumulator updates. + // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => - // Do nothing here; the TaskScheduler handles these failures and resubmits the task. + // Do nothing here; the TaskScheduler handles these failures and resubmits the task. case _: ExecutorLostFailure | UnknownReason => - // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler - // will abort the job. + // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler + // will abort the job. 
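// --- Editorial sketch, not part of this patch: the host-wide shuffle-output removal above is
// gated by configuration; the key below is the upstream Spark name and should be treated as an
// assumption for this fork. When the executor's worker host is being decommissioned, the hunk
// above takes the same host-wide path and ignores shuffleFileLostEpoch.
import org.apache.spark.SparkConf

object FetchFailureCleanupConfSketch {
  def conf(): SparkConf = new SparkConf()
    // on FetchFailed, unregister shuffle output for the whole host rather than just one executor
    .set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")
}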
} } @@ -1815,89 +2022,113 @@ private[spark] class DAGScheduler( } /** - * Responds to an executor being lost. This is called inside the event loop, so it assumes it can - * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. - * - * We will also assume that we've lost all shuffle blocks associated with the executor if the - * executor serves its own blocks (i.e., we're not using external shuffle), the entire slave - * is lost (likely including the shuffle service), or a FetchFailed occurred, in which case we - * presume all shuffle data related to this executor to be lost. - * - * Optionally the epoch during which the failure was caught can be passed to avoid allowing - * stray fetch failures from possibly retriggering the detection of a node as lost. - */ + * Responds to an executor being lost. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. + * + * We will also assume that we've lost all shuffle blocks associated with the executor if the + * executor serves its own blocks (i.e., we're not using an external shuffle service), or the + * entire Standalone worker is lost. + */ private[scheduler] def handleExecutorLost( - execId: String, - workerLost: Boolean): Unit = { + execId: String, + workerHost: Option[String]): Unit = { // if the cluster manager explicitly tells us that the entire worker was lost, then // we know to unregister shuffle output. (Note that "worker" specifically refers to the process // from a Standalone cluster, where the shuffle service lives in the Worker.) val remoteShuffleClass = classOf[RemoteShuffleManager].getName val remoteShuffleEnabled = env.conf.get("spark.shuffle.manager") == remoteShuffleClass - // If remote shuffle is enabled, shuffle files will be taken care of by remote storage, the - // unregistering and rerun of certain tasks are not needed. - val fileLost = - !remoteShuffleEnabled && (workerLost || !env.blockManager.externalShuffleServiceEnabled) + val fileLost = !remoteShuffleEnabled && (workerHost.isDefined || !env.blockManager.externalShuffleServiceEnabled) removeExecutorAndUnregisterOutputs( execId = execId, fileLost = fileLost, - hostToUnregisterOutputs = None, + hostToUnregisterOutputs = workerHost, maybeEpoch = None) } + /** + * Handles removing an executor from the BlockManagerMaster as well as unregistering shuffle + * outputs for the executor or optionally its host. + * + * @param execId executor to be removed + * @param fileLost If true, indicates that we assume we've lost all shuffle blocks associated + * with the executor; this happens if the executor serves its own blocks (i.e., we're not + * using an external shuffle service), the entire Standalone worker is lost, or a FetchFailed + * occurred (in which case we presume all shuffle data related to this executor to be lost). 
+ * @param hostToUnregisterOutputs (optional) executor host if we're unregistering all the + * outputs on the host + * @param maybeEpoch (optional) the epoch during which the failure was caught (this prevents + * reprocessing for follow-on fetch failures) + */ private def removeExecutorAndUnregisterOutputs( - execId: String, - fileLost: Boolean, - hostToUnregisterOutputs: Option[String], - maybeEpoch: Option[Long] = None): Unit = { + execId: String, + fileLost: Boolean, + hostToUnregisterOutputs: Option[String], + maybeEpoch: Option[Long] = None, + ignoreShuffleFileLostEpoch: Boolean = false): Unit = { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) - if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { - failedEpoch(execId) = currentEpoch - logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) + logDebug(s"Considering removal of executor $execId; " + + s"fileLost: $fileLost, currentEpoch: $currentEpoch") + if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) { + executorFailureEpoch(execId) = currentEpoch + logInfo(s"Executor lost: $execId (epoch $currentEpoch)") + if (pushBasedShuffleEnabled) { + // Remove fetchFailed host in the shuffle push merger list for push based shuffle + hostToUnregisterOutputs.foreach( + host => blockManagerMaster.removeShufflePushMergerLocation(host)) + } blockManagerMaster.removeExecutor(execId) - if (fileLost) { + clearCacheLocs() + } + if (fileLost) { + val remove = if (ignoreShuffleFileLostEpoch) { + true + } else if (!shuffleFileLostEpoch.contains(execId) || + shuffleFileLostEpoch(execId) < currentEpoch) { + shuffleFileLostEpoch(execId) = currentEpoch + true + } else { + false + } + if (remove) { hostToUnregisterOutputs match { case Some(host) => - logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch)) + logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)") mapOutputTracker.removeOutputsOnHost(host) case None => - logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) + logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)") mapOutputTracker.removeOutputsOnExecutor(execId) } - clearCacheLocs() - - } else { - logDebug("Additional executor lost message for %s (epoch %d)".format(execId, currentEpoch)) } } } /** - * Responds to a worker being removed. This is called inside the event loop, so it assumes it can - * modify the scheduler's internal state. Use workerRemoved() to post a loss event from outside. - * - * We will assume that we've lost all shuffle blocks associated with the host if a worker is - * removed, so we will remove them all from MapStatus. - * - * @param workerId identifier of the worker that is removed. - * @param host host of the worker that is removed. - * @param message the reason why the worker is removed. - */ + * Responds to a worker being removed. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use workerRemoved() to post a loss event from outside. + * + * We will assume that we've lost all shuffle blocks associated with the host if a worker is + * removed, so we will remove them all from MapStatus. + * + * @param workerId identifier of the worker that is removed. + * @param host host of the worker that is removed. + * @param message the reason why the worker is removed. 
+ */ private[scheduler] def handleWorkerRemoved( - workerId: String, - host: String, - message: String): Unit = { + workerId: String, + host: String, + message: String): Unit = { logInfo("Shuffle files lost for worker %s on host %s".format(workerId, host)) mapOutputTracker.removeOutputsOnHost(host) clearCacheLocs() } private[scheduler] def handleExecutorAdded(execId: String, host: String): Unit = { - // remove from failedEpoch(execId) ? - if (failedEpoch.contains(execId)) { + // remove from executorFailureEpoch(execId) ? + if (executorFailureEpoch.contains(execId)) { logInfo("Host added was in lost list earlier: " + host) - failedEpoch -= execId + executorFailureEpoch -= execId } + shuffleFileLostEpoch -= execId } private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]): Unit = { @@ -1928,12 +2159,12 @@ private[spark] class DAGScheduler( } /** - * Marks a stage as finished and removes it from the list of running stages. - */ + * Marks a stage as finished and removes it from the list of running stages. + */ private def markStageAsFinished( - stage: Stage, - errorMessage: Option[String] = None, - willRetry: Boolean = false): Unit = { + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" @@ -1960,13 +2191,13 @@ private[spark] class DAGScheduler( } /** - * Aborts all jobs depending on a particular Stage. This is called in response to a task set - * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. - */ + * Aborts all jobs depending on a particular Stage. This is called in response to a task set + * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + */ private[scheduler] def abortStage( - failedStage: Stage, - reason: String, - exception: Option[Throwable]): Unit = { + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed. return @@ -2019,9 +2250,9 @@ private[spark] class DAGScheduler( /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ private def failJobAndIndependentStages( - job: ActiveJob, - failureReason: String, - exception: Option[Throwable] = None): Unit = { + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { if (cancelRunningIndependentStages(job, failureReason)) { // SPARK-15783 important to cleanup state first, just for tests where we have some asserts // against the state. Otherwise we have a *little* bit of flakiness in the tests. @@ -2065,30 +2296,30 @@ private[spark] class DAGScheduler( } /** - * Gets the locality information associated with a partition of a particular RDD. - * - * This method is thread-safe and is called from both DAGScheduler and SparkContext. - * - * @param rdd whose partitions are to be looked at - * @param partition to lookup locality information for - * @return list of machines that are preferred by the partition - */ + * Gets the locality information associated with a partition of a particular RDD. + * + * This method is thread-safe and is called from both DAGScheduler and SparkContext. 
+ * + * @param rdd whose partitions are to be looked at + * @param partition to lookup locality information for + * @return list of machines that are preferred by the partition + */ private[spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { getPreferredLocsInternal(rdd, partition, new HashSet) } /** - * Recursive implementation for getPreferredLocs. - * - * This method is thread-safe because it only accesses DAGScheduler state through thread-safe - * methods (getCacheLocs()); please be careful when modifying this method, because any new - * DAGScheduler state accessed by it may require additional synchronization. - */ + * Recursive implementation for getPreferredLocs. + * + * This method is thread-safe because it only accesses DAGScheduler state through thread-safe + * methods (getCacheLocs()); please be careful when modifying this method, because any new + * DAGScheduler state accessed by it may require additional synchronization. + */ private def getPreferredLocsInternal( - rdd: RDD[_], - partition: Int, - visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { + rdd: RDD[_], + partition: Int, + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 if (!visited.add((rdd, partition))) { @@ -2150,8 +2381,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer /** - * The main event loop of the DAG scheduler. - */ + * The main event loop of the DAG scheduler. + */ override def onReceive(event: DAGSchedulerEvent): Unit = { val timerContext = timer.time() try { @@ -2184,11 +2415,12 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler dagScheduler.handleExecutorAdded(execId, host) case ExecutorLost(execId, reason) => - val workerLost = reason match { - case SlaveLost(_, true) => true - case _ => false + val workerHost = reason match { + case ExecutorProcessLost(_, workerHost, _) => workerHost + case ExecutorDecommission(workerHost) => workerHost + case _ => None } - dagScheduler.handleExecutorLost(execId, workerLost) + dagScheduler.handleExecutorLost(execId, workerHost) case WorkerRemoved(workerId, host, message) => dagScheduler.handleWorkerRemoved(workerId, host, message) @@ -2199,6 +2431,12 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case SpeculativeTaskSubmitted(task) => dagScheduler.handleSpeculativeTaskSubmitted(task) + case UnschedulableTaskSetAdded(stageId, stageAttemptId) => + dagScheduler.handleUnschedulableTaskSetAdded(stageId, stageAttemptId) + + case UnschedulableTaskSetRemoved(stageId, stageAttemptId) => + dagScheduler.handleUnschedulableTaskSetRemoved(stageId, stageAttemptId) + case GettingResultEvent(taskInfo) => dagScheduler.handleGetTaskResult(taskInfo) diff --git a/shuffle-hadoop/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala b/shuffle-hadoop/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala index 3b84e6b1..0a623d79 100644 --- a/shuffle-hadoop/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala +++ b/shuffle-hadoop/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala @@ -89,25 +89,6 @@ private[spark] class RemoteShuffleManager(private val conf: SparkConf) extends S * Called on executors by reduce tasks. 
*/ override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - - val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, startPartition, endPartition) - - new RemoteShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - shuffleBlockResolver, - blocksByAddress, - context, - metrics, - shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) - } - - override def getReaderForRange[K, C]( handle: ShuffleHandle, startMapIndex: Int, endMapIndex: Int, @@ -116,7 +97,7 @@ private[spark] class RemoteShuffleManager(private val conf: SparkConf) extends S context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange( + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) new RemoteShuffleReader( diff --git a/shuffle-hadoop/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala b/shuffle-hadoop/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala index 6879d435..533ba798 100644 --- a/shuffle-hadoop/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala +++ b/shuffle-hadoop/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala @@ -21,7 +21,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.ref.WeakReference -import org.scalatest.Matchers import org.scalatest.concurrent.Eventually import org.apache.spark._ @@ -38,8 +37,7 @@ import org.apache.spark.util.CompletionIterator */ class RemoteAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext - with Eventually - with Matchers{ + with Eventually { import TestUtils.{assertNotSpilled, assertSpilled} private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS @@ -250,7 +248,7 @@ class RemoteAppendOnlyMapSuite extends SparkFunSuite } test("spilling with compression and encryption") { - testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true) + testSimpleSpilling(Some(CompressionCodec.FALLBACK_COMPRESSION_CODEC), encrypt = true) } /**
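// --- Editorial sketch, not part of this patch: with getReaderForRange folded into getReader, the
// remote reader now always resolves blocks through the map-index-range form of
// getMapSizesByExecutorId, mirroring the substitution in the hunk above. Expressed as a
// standalone helper; this touches private[spark] APIs, so it only compiles from a file under an
// org.apache.spark package, and Int.MaxValue follows Spark's convention for "all map outputs".
// The object name is hypothetical.
package org.apache.spark.shuffle.remote

import org.apache.spark.SparkEnv

object MapSizesLookupSketch {
  def allBlocksFor(shuffleId: Int, numReducePartitions: Int) = {
    SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
      shuffleId,
      0, Int.MaxValue,        // startMapIndex / endMapIndex: cover every map output
      0, numReducePartitions) // startPartition / endPartition
  }
}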