[CELEBORN-2315] Add iterator fully-consumed validation after shuffle write #3672

xumingming wants to merge 1 commit into
Conversation
Force-pushed from 5a50c71 to dbd6473
@gauravkm @RexXiong @SteNicholas Could you also take a look at this one?

@RexXiong @SteNicholas @gauravkm Gentle ping :)

I'll take a look at this PR over the next couple of days.
Pull request overview
This PR adds a correctness guard to Celeborn’s Spark shuffle writers (Spark 2 and Spark 3 variants): after finishing the write path, it validates that the upstream records iterator was fully consumed and kills the task if it wasn’t, reducing the risk of silent data loss.
Changes:
- Add a SparkUtils.assertIteratorFullyConsumed(...) helper and invoke it at the end of shuffle-writer close paths (see the sketch after this list).
- Refactor the Hash/Sort-based writers' write flows to propagate an iteratorHasNext signal from the write loop to close/validation.
- Extend TaskInterruptedHelper to support an optional message in TaskKilledException, and add unit tests for the new assertion.
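A minimal sketch of what such a helper could look like, assuming it simply checks the flag and raises Spark's TaskKilledException; the actual body and message in SparkUtils are not shown on this page:

```java
import org.apache.spark.TaskKilledException;

public final class SparkUtilsSketch {
  // If the write loop returned while records are still pending, the shuffle
  // output would be incomplete, so fail the task loudly instead of risking
  // silent data loss.
  public static void assertIteratorFullyConsumed(boolean iteratorHasNext) {
    if (iteratorHasNext) {
      throw new TaskKilledException(
          "Shuffle write finished but the records iterator was not fully consumed");
    }
  }
}
```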
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java | Adds assertIteratorFullyConsumed helper (Spark 3). |
| client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java | Adds assertIteratorFullyConsumed helper (Spark 2). |
| client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java | Returns iterator-consumption status from doWrite and validates on close. |
| client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java | Same iterator-consumption validation wiring for hash-based writer. |
| client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java | Same iterator-consumption validation wiring (Spark 2). |
| client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java | Same iterator-consumption validation wiring (Spark 2). |
| client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/TaskInterruptedHelper.java | Adds overload to include a message in TaskKilledException. |
| client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java | Adds unit tests for the new iterator-consumed assertion and mocks kill reason. |
| client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java | Adds unit tests for the new iterator-consumed assertion and mocks kill reason. |
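The TaskInterruptedHelper row above suggests a shape like the following. This is a hypothetical sketch: only the class name and the "optional message" behavior come from the summary; Spark's TaskKilledException does accept a reason string, and its no-arg constructor defaults the reason.

```java
import org.apache.spark.TaskKilledException;

public final class TaskInterruptedHelperSketch {
  // Existing behavior: kill the task without a specific reason.
  public static void throwTaskKillException() {
    throw new TaskKilledException();
  }

  // New overload: carry a message so logs and the Spark UI can show
  // why the task was killed.
  public static void throwTaskKillException(String message) {
    throw new TaskKilledException(message);
  }
}
```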
```java
  boolean doWrite(scala.collection.Iterator<Product2<K, V>> records)
```
Some of the doWrite(...) methods are called from testing code (SortBasedShuffleWriterSuiteJ), so I'd prefer to keep all doWrite(...) methods package-private, to stay consistent and simple.
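For context, package-private visibility is enough for the test suite because it lives in the same package as the writers. A hypothetical illustration (the class name, generics, and checked exceptions here are assumptions, not the real suite):

```java
package org.apache.spark.shuffle.celeborn;

// Same package as SortBasedShuffleWriter, so the package-private
// doWrite(...) is directly callable without widening its visibility.
public class DoWriteAccessExample {
  <K, V, C> boolean invokeDoWrite(
      SortBasedShuffleWriter<K, V, C> writer,
      scala.collection.Iterator<scala.Product2<K, V>> records)
      throws java.io.IOException {
    return writer.doWrite(records);
  }
}
```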
```java
  boolean doWrite(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
    if (canUseFastWrite()) {
      fastWrite0(records);
      return records.hasNext();
    } else if (dep.mapSideCombine()) {
      if (dep.aggregator().isEmpty()) {
        throw new UnsupportedOperationException(
            "When using map side combine, an aggregator must be specified.");
      }
      scala.collection.Iterator combinedIterator =
          dep.aggregator().get().combineValuesByKey(records, taskContext);
      write0(combinedIterator);
      return combinedIterator.hasNext();
    } else {
      write0(records);
      return records.hasNext();
    }
  }
```
Some of the doWrite(...) methods are called from testing code (SortBasedShuffleWriterSuiteJ), so I'd prefer to keep all doWrite(...) methods package-private, to stay consistent and simple.
```java
  boolean doWrite(scala.collection.Iterator<Product2<K, V>> records)
      throws IOException, InterruptedException {
    if (canUseFastWrite()) {
      fastWrite0(records);
      return records.hasNext();
    } else if (dep.mapSideCombine()) {
      if (dep.aggregator().isEmpty()) {
        throw new UnsupportedOperationException(
            "When using map side combine, an aggregator must be specified.");
      }
      scala.collection.Iterator combinedIterator =
          dep.aggregator().get().combineValuesByKey(records, taskContext);
      write0(combinedIterator);
      return combinedIterator.hasNext();
    } else {
      write0(records);
      return records.hasNext();
    }
  }
```
Some of the doWrite(...) methods are called from testing code (SortBasedShuffleWriterSuiteJ), so I'd prefer to keep all doWrite(...) methods package-private, to stay consistent and simple.
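Separately from the visibility question, the returned flag has to reach the validation at close time. A fragment-level sketch (not a complete class) of how write(...) might wire it through, assuming close(...) accepts the flag; the actual plumbing in the PR may differ:

```java
public void write(scala.collection.Iterator<scala.Product2<K, V>> records)
    throws IOException {
  boolean iteratorHasNext = doWrite(records); // true => records left over
  close(iteratorHasNext);                     // cleanup, then validation
}
```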
```java
      shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);

      updateMapStatus();

      SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);

      sendBufferPool.returnBuffer(sendBuffers);
      sendBuffers = null;
      sendOffsets = null;
```
This is a good catch. I have moved SparkUtils.assertIteratorFullyConsumed down to the line just above shuffleClient.mapperEnd, so that after the change:
- All resources are cleaned up before we kill the task.
- The MapperEnd event is not submitted.
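A rough sketch of the close ordering this reply describes. The argument lists (including mapperEnd's) and the surrounding method are assumptions based on the excerpt above, not copied from the Celeborn sources:

```java
private void close() throws IOException {
  // 1. Flush remaining data and release buffers first, so killing the
  //    task cannot leak resources.
  shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
  updateMapStatus();
  sendBufferPool.returnBuffer(sendBuffers);
  sendBuffers = null;
  sendOffsets = null;

  // 2. Validate only after cleanup; throws TaskKilledException if the
  //    iterator still has records.
  SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);

  // 3. MapperEnd is reported last, so an unconsumed iterator is never
  //    recorded as a successful map output. (Signature assumed.)
  shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
}
```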
…write

Adds a post-write safety check to HashBasedShuffleWriter and SortBasedShuffleWriter: after the write loop completes, verify the input iterator was fully consumed. If records remain, kill the task with TaskKilledException. This guards against silent data loss.
Force-pushed from dbd6473 to bba2479
@SteNicholas I have made all the necessary changes. Could you take another look?
What changes were proposed in this pull request?
Adds a post-write safety check to HashBasedShuffleWriter and SortBasedShuffleWriter: after the write loop completes, verify the input iterator was fully consumed. If records remain, kill the task with TaskKilledException. This guards against silent data loss.
Why are the changes needed?
It adds another layer of correctness guarantees: if an upstream bug ever leaves shuffle records unconsumed, the task now fails loudly instead of silently committing incomplete output.
Does this PR resolve a correctness bug?
It enhances the correctness guarantee rather than resolving a known bug.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
Unit tests (new cases in CelebornShuffleWriterSuiteBase for the iterator-consumed assertion).
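A hypothetical shape for such a unit test, inferred from the file summary above; the real cases live in CelebornShuffleWriterSuiteBase and may be structured differently, and the boolean-flag signature is the same assumption used in the earlier sketch:

```java
import static org.junit.Assert.assertThrows;

import org.apache.spark.TaskKilledException;
import org.junit.Test;

public class AssertIteratorFullyConsumedSketchTest {
  @Test
  public void passesWhenIteratorIsExhausted() {
    // No remaining records: the assertion is a no-op.
    SparkUtils.assertIteratorFullyConsumed(false);
  }

  @Test
  public void killsTaskWhenRecordsRemain() {
    // Remaining records: the task must be killed.
    assertThrows(
        TaskKilledException.class,
        () -> SparkUtils.assertIteratorFullyConsumed(true));
  }
}
```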