[CELEBORN-2315] Add iterator fully-consumed validation after shuffle write#3672

Open
xumingming wants to merge 1 commit into apache:main from xumingming:iterator-fully-consumed-check

Conversation

@xumingming
Contributor

What changes were proposed in this pull request?

Adds a post-write safety check to HashBasedShuffleWriter and SortBasedShuffleWriter: after the write loop completes, verify the input iterator was fully consumed. If records remain, kill the task with TaskKilledException. This guards against silent data loss.
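
For illustration, a minimal sketch of the shape such a check could take; this is an assumption based on the description above, not the PR's exact code (the real helper lives in SparkUtils and routes the kill through TaskInterruptedHelper):

// Hedged sketch, not the actual Celeborn code: fail loudly if the write loop
// finished while the input iterator still has records.
public final class IteratorConsumptionCheck {
  public static void assertIteratorFullyConsumed(boolean iteratorHasNext) {
    if (iteratorHasNext) {
      // Reporting a successful map status now would silently drop the
      // remaining records, so kill the task instead of completing it.
      throw new org.apache.spark.TaskKilledException(
          "Shuffle write completed but the records iterator was not fully consumed");
    }
  }
}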

Why are the changes needed?

It adds another layer of correctness guarantees: if a writer bug ever leaves part of the input unread, the task fails loudly instead of silently producing incomplete shuffle output.

Does this PR resolve a correctness bug?

It does not fix a known bug; it strengthens the existing correctness guarantee.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Unit tests (new cases in CelebornShuffleWriterSuiteBase for both the Spark 2 and Spark 3 clients).

@xumingming xumingming force-pushed the iterator-fully-consumed-check branch 2 times, most recently from 5a50c71 to dbd6473 on April 23, 2026 at 12:33
@xumingming
Contributor Author

@gauravkm @RexXiong @SteNicholas Could you also take a look at this one?

@xumingming
Contributor Author

@RexXiong @SteNicholas @gauravkm Gentle ping :)

@afterincomparableyum
Contributor

I'll help take a look at this PR over the next couple of days.


Copilot AI left a comment

Pull request overview

This PR adds a correctness guard to Celeborn’s Spark shuffle writers (Spark 2 and Spark 3 variants): after finishing the write path, it validates that the upstream records iterator was fully consumed and kills the task if it wasn’t, reducing the risk of silent data loss.

Changes:

  • Add SparkUtils.assertIteratorFullyConsumed(...) helper and invoke it at the end of shuffle-writer close paths.
  • Refactor Hash/Sort-based writers’ write flows to propagate an iteratorHasNext signal from the write loop to close/validation.
  • Extend TaskInterruptedHelper to support an optional message in TaskKilledException and add unit tests for the new assertion (a hedged sketch of such an overload follows below).
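
As a rough illustration of that last bullet, here is what a message-carrying overload might look like; the method names are assumptions, not the PR's exact API:

// Hedged sketch of TaskInterruptedHelper with an optional kill reason.
// Spark's TaskKilledException accepts a reason string shown in the UI.
import org.apache.spark.TaskKilledException;

public class TaskInterruptedHelper {
  public static void throwTaskKillException() {
    // Existing no-message path (assumed): falls back to Spark's default reason.
    throw new TaskKilledException();
  }

  public static void throwTaskKillException(String reason) {
    // New overload (assumed name): attach a descriptive reason to the kill.
    throw new TaskKilledException(reason);
  }
}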

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.

Summary per file:

  • client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java: Adds the assertIteratorFullyConsumed helper (Spark 3).
  • client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java: Adds the assertIteratorFullyConsumed helper (Spark 2).
  • client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java: Returns iterator-consumption status from doWrite and validates on close.
  • client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java: Same iterator-consumption validation wiring for the hash-based writer.
  • client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java: Same iterator-consumption validation wiring (Spark 2).
  • client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java: Same iterator-consumption validation wiring (Spark 2).
  • client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/TaskInterruptedHelper.java: Adds an overload to include a message in TaskKilledException.
  • client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java: Adds unit tests for the new iterator-consumed assertion and mocks the kill reason.
  • client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java: Adds unit tests for the new iterator-consumed assertion and mocks the kill reason.
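
For context on the two test files above, a minimal JUnit sketch (assumed test names and structure, not the PR's actual suite) of how the assertion could be exercised:

// Hedged sketch; the real tests in CelebornShuffleWriterSuiteBase drive full
// shuffle writers and mock the task kill reason.
package org.apache.spark.shuffle.celeborn;

import static org.junit.Assert.assertThrows;
import org.apache.spark.TaskKilledException;
import org.junit.Test;

public class AssertIteratorFullyConsumedTest {
  @Test
  public void fullyConsumedIteratorPasses() {
    // hasNext == false after the write loop: the input was drained, no kill.
    SparkUtils.assertIteratorFullyConsumed(false);
  }

  @Test
  public void unconsumedIteratorKillsTask() {
    // hasNext == true after the write loop: remaining records would be lost.
    assertThrows(
        TaskKilledException.class, () -> SparkUtils.assertIteratorFullyConsumed(true));
  }
}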

boolean doWrite(scala.collection.Iterator<Product2<K, V>> records)
Contributor Author

Some of the doWrite(...) methods are called from testing code (SortBasedShuffleWriterSuiteJ), so I'd prefer to keep all doWrite(...) package-private, to be consistent and simple.

Comment on lines +160 to +177
boolean doWrite(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
  if (canUseFastWrite()) {
    fastWrite0(records);
    return records.hasNext();
  } else if (dep.mapSideCombine()) {
    if (dep.aggregator().isEmpty()) {
      throw new UnsupportedOperationException(
          "When using map side combine, an aggregator must be specified.");
    }
    scala.collection.Iterator combinedIterator =
        dep.aggregator().get().combineValuesByKey(records, taskContext);
    write0(combinedIterator);
    return combinedIterator.hasNext();
  } else {
    write0(records);
    return records.hasNext();
  }
}
Contributor Author

Some of the doWrite(...) methods are called from testing code (SortBasedShuffleWriterSuiteJ), so I'd prefer to keep all doWrite(...) package-private, to be consistent and simple.
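
To make the visibility point concrete, a minimal hypothetical sketch: a class in the same package can call the package-private doWrite directly, so no access widening is needed. Class name and the writer's type parameters here are assumptions.

// Hedged sketch, not the actual suite: package-private members are visible to
// any class in org.apache.spark.shuffle.celeborn, including test code.
package org.apache.spark.shuffle.celeborn;

import scala.Product2;

public class DoWriteVisibilityExample {
  <K, V> boolean exerciseDoWrite(
      SortBasedShuffleWriter<K, V, ?> writer,
      scala.collection.Iterator<Product2<K, V>> records) throws Exception {
    // Legal here because this class shares the writer's package; a caller in
    // any other package would need doWrite to be widened to public.
    return writer.doWrite(records);
  }
}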

Comment on lines +181 to +199
boolean doWrite(scala.collection.Iterator<Product2<K, V>> records)
    throws IOException, InterruptedException {
  if (canUseFastWrite()) {
    fastWrite0(records);
    return records.hasNext();
  } else if (dep.mapSideCombine()) {
    if (dep.aggregator().isEmpty()) {
      throw new UnsupportedOperationException(
          "When using map side combine, an aggregator must be specified.");
    }
    scala.collection.Iterator combinedIterator =
        dep.aggregator().get().combineValuesByKey(records, taskContext);
    write0(combinedIterator);
    return combinedIterator.hasNext();
  } else {
    write0(records);
    return records.hasNext();
  }
}
Contributor Author

Some of the doWrite(...) methods are called from testing code (SortBasedShuffleWriterSuiteJ), so I'd prefer to keep all doWrite(...) package-private, to be consistent and simple.

Comment on lines 368 to 376
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);

updateMapStatus();

SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);

sendBufferPool.returnBuffer(sendBuffers);
sendBuffers = null;
sendOffsets = null;
Contributor Author

This is a good catch. I have moved SparkUtils.assertIteratorFullyConsumed down to the line just above "shuffleClient.mapperEnd", so after the change (a rough sketch follows below):

  • All resources are cleaned up before we kill the task.
  • The MapperEnd event is not submitted.
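
A simplified sketch of the revised close-path ordering being described; the surrounding fields and the mapperEnd argument list are assumptions, not the exact Celeborn code:

// Hedged sketch of the reordered close path: release resources first, then
// validate, so a kill never leaks buffers and never reports MapperEnd.
void close() throws IOException {
  shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
  updateMapStatus();

  // Buffers go back to the pool before any potential TaskKilledException.
  sendBufferPool.returnBuffer(sendBuffers);
  sendBuffers = null;
  sendOffsets = null;

  // Moved down: if records remain, this throws before mapperEnd, so the
  // MapperEnd event that would mark the map task as finished is never sent.
  SparkUtils.assertIteratorFullyConsumed(iteratorHasNext);

  shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
}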

[CELEBORN-2315] Add iterator fully-consumed validation after shuffle write

Adds a post-write safety check to HashBasedShuffleWriter and SortBasedShuffleWriter:
after the write loop completes, verify the input iterator was fully consumed.
If records remain, kill the task with TaskKilledException. This guards against
silent data loss.
@xumingming xumingming force-pushed the iterator-fully-consumed-check branch from dbd6473 to bba2479 on May 13, 2026 at 09:06
@xumingming
Contributor Author

@SteNicholas I have made all the necessary changes; could you take another look?
