Skip to content

Commit 6e9c203

Browse files
committed
fix: improve pull resume reliability and progress bar display on retry
- Only accept HTTP 206 with a matching Content-Range start offset as a successful Range response in rangeTransport; a 200 response means the server ignored the Range header and is sending from byte 0, so appending it to the partial file would corrupt the blob. A misbehaving server returning 206 with a different range is also rejected. - Preserve .incomplete files on all read errors, not just context cancellation, so every kind of transient failure (network reset, stream error, etc.) can be resumed on the next attempt. - Stop rolling back fully-downloaded layer blobs on Write failure. Layer blobs are content-addressed and immutable; keeping them lets a subsequent pull skip already-completed layers entirely instead of re-downloading them. The manifest, config, and index rollback still runs to leave the store in a consistent non-indexed state. - Print a blank line to stdout before each retry so that orphaned progress bars from the failed attempt are visually separated from the new attempt's bars, preventing garbled terminal output. - Increase Pull and Push max retries from 3 to 4 and update TestPullMaxRetriesExhausted to match (5 total attempts, error message says "after 4 retries"). Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 4d0a3bf commit 6e9c203

6 files changed

Lines changed: 64 additions & 46 deletions

File tree

cmd/cli/desktop/desktop.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, b
117117
hfToken = os.Getenv("HF_TOKEN")
118118
}
119119

120-
return c.withRetries("download", 3, printer, func(attempt int) (string, bool, error, bool) {
120+
return c.withRetries("download", 4, printer, func(attempt int) (string, bool, error, bool) {
121121
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{
122122
From: model,
123123
BearerToken: hfToken,
@@ -237,6 +237,11 @@ func (c *Client) withRetries(
237237
if attempt > 0 {
238238
// Calculate exponential backoff: 2^(attempt-1) seconds (1s, 2s, 4s)
239239
backoffDuration := time.Duration(1<<uint(attempt-1)) * time.Second
240+
// Print a blank line to stdout so that any progress bars drawn during
241+
// the previous attempt are visually separated from the retry attempt.
242+
// This prevents the new progress bars from overwriting the old ones
243+
// when the terminal display is reset on each retry.
244+
printer.Println("")
240245
printer.PrintErrf("Retrying %s (attempt %d/%d) in %v...\n", operationName, attempt, maxRetries, backoffDuration)
241246
time.Sleep(backoffDuration)
242247
}
@@ -263,7 +268,7 @@ func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, b
263268
hfToken = os.Getenv("HF_TOKEN")
264269
}
265270

266-
return c.withRetries("push", 3, printer, func(attempt int) (string, bool, error, bool) {
271+
return c.withRetries("push", 4, printer, func(attempt int) (string, bool, error, bool) {
267272
pushPath := inference.ModelsPrefix + "/" + model + "/push"
268273
var body io.Reader
269274
if hfToken != "" {

cmd/cli/desktop/desktop_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,13 @@ func TestPullMaxRetriesExhausted(t *testing.T) {
191191
mockContext := NewContextForMock(mockClient)
192192
client := New(mockContext)
193193

194-
// All 4 attempts (1 initial + 3 retries) fail with network error
195-
mockClient.EXPECT().Do(gomock.Any()).Return(nil, io.EOF).Times(4)
194+
// All 5 attempts (1 initial + 4 retries) fail with network error
195+
mockClient.EXPECT().Do(gomock.Any()).Return(nil, io.EOF).Times(5)
196196

197197
printer := NewSimplePrinter(func(s string) {})
198198
_, _, err := client.Pull(modelName, printer)
199199
assert.Error(t, err)
200-
assert.Contains(t, err.Error(), "download failed after 3 retries")
200+
assert.Contains(t, err.Error(), "download failed after 4 retries")
201201
}
202202

203203
func TestPushRetryOnNetworkError(t *testing.T) {

pkg/distribution/internal/store/blobs.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package store
22

33
import (
4-
"context"
54
"crypto/sha256"
65
"encoding/hex"
7-
"errors"
86
"fmt"
97
"io"
108
"os"
@@ -199,11 +197,10 @@ func (s *LocalStore) WriteBlobWithResume(diffID oci.Hash, r io.Reader, digestStr
199197
buf := make([]byte, 1)
200198
n, readErr := r.Read(buf)
201199
if readErr != nil && readErr != io.EOF {
202-
// Clean up the incomplete file on read error (unless it's a context cancellation
203-
// which should preserve the file for future resume attempts)
204-
if !errors.Is(readErr, context.Canceled) && !errors.Is(readErr, context.DeadlineExceeded) {
205-
_ = os.Remove(incompletePath)
206-
}
200+
// Preserve the incomplete file on all errors so that the next
201+
// attempt can resume from where this one left off. Stale
202+
// incomplete files are cleaned up by CleanupStaleIncompleteFiles
203+
// during store initialisation (default: files older than 7 days).
207204
return fmt.Errorf("read first byte: %w", readErr)
208205
}
209206

pkg/distribution/internal/store/store.go

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -374,29 +374,16 @@ func (s *LocalStore) Write(mdl oci.Image, tags []string, w io.Writer, opts ...Wr
374374
return err
375375
}
376376

377-
// Collect new layer digests
378-
var newLayerDigests []oci.Hash
379-
for _, result := range results {
380-
if result.created {
381-
newLayerDigests = append(newLayerDigests, result.diffID)
382-
}
383-
}
384-
385-
if len(newLayerDigests) > 0 {
386-
digests := append([]oci.Hash(nil), newLayerDigests...)
387-
cleanups = append(cleanups, func() error {
388-
var errs []error
389-
for _, dg := range digests {
390-
if err := s.removeBlob(dg); err != nil && !errors.Is(err, os.ErrNotExist) {
391-
errs = append(errs, fmt.Errorf("remove blob %s: %w", dg, err))
392-
}
393-
}
394-
if len(errs) > 0 {
395-
return errors.Join(errs...)
396-
}
397-
return nil
398-
})
399-
}
377+
// Do not register completed layer blobs in the rollback list.
378+
//
379+
// Layer blobs are content-addressed: a blob identified by its diffID is
380+
// immutable and may be shared across multiple models. Rolling them back
381+
// when a later step fails (e.g. writing the manifest, tagging) would
382+
// discard already-downloaded data that is still valid, forcing a full
383+
// re-download on the next attempt instead of resuming only the layer
384+
// that was actually in progress. The manifest/config/index cleanup below
385+
// is sufficient to leave the store in a consistent state.
386+
_ = results // results used above for error checking; layer blobs are retained
400387

401388
// Write the manifest
402389
digest, err := mdl.Digest()

pkg/distribution/internal/store/store_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,14 +424,19 @@ func TestWriteRollsBackOnTagFailure(t *testing.T) {
424424
t.Fatalf("expected config blob to be cleaned up, stat error: %v", err)
425425
}
426426

427+
// Layer blobs are content-addressed and are intentionally retained even
428+
// after a failed write. They may be reused by a subsequent pull of the
429+
// same or another model, and they allow the download to resume rather
430+
// than restart from byte 0. Only the manifest, config, and index are
431+
// rolled back to leave the store in a consistent (non-indexed) state.
427432
for _, digestStr := range diffIDs {
428433
parts := strings.SplitN(digestStr, ":", 2)
429434
if len(parts) != 2 {
430435
t.Fatalf("unexpected diffID format: %q", digestStr)
431436
}
432437
layerPath := filepath.Join(storePath, "blobs", parts[0], parts[1])
433-
if _, err := os.Stat(layerPath); !errors.Is(err, os.ErrNotExist) {
434-
t.Fatalf("expected layer blob %q to be cleaned up, stat error: %v", layerPath, err)
438+
if _, err := os.Stat(layerPath); err != nil {
439+
t.Fatalf("expected layer blob %q to be retained for future resume, stat error: %v", layerPath, err)
435440
}
436441
}
437442

pkg/distribution/oci/remote/remote.go

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,24 +197,48 @@ func (t *rangeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
197197
return resp, err
198198
}
199199

200-
// If we requested a Range, record success only if the server accepted the range request
201-
// Servers should return 206 (Partial Content) for successful range requests,
202-
// but some may return 200 with the partial content, so we record success for both
203-
if requestedOffset > 0 {
204-
if resp.StatusCode == http.StatusPartialContent || resp.StatusCode == http.StatusOK {
205-
// Record in RangeSuccess tracker so WriteBlob can check it
200+
// If we requested a Range, record success only when the server honoured it
201+
// with 206 Partial Content and a matching Content-Range start offset. A 200
202+
// response means the server ignored the Range header and is sending the full
203+
// file from byte 0; appending that stream to the existing partial file would
204+
// produce a corrupt blob. We also validate the Content-Range start offset to
205+
// guard against a misbehaving server that returns 206 with a different range.
206+
if requestedOffset > 0 && resp.StatusCode == http.StatusPartialContent {
207+
if rangeStartMatchesOffset(resp.Header.Get("Content-Range"), requestedOffset) {
206208
if rs := GetRangeSuccess(req.Context()); rs != nil {
207209
rs.Add(digest, requestedOffset)
208210
}
209211
}
210-
// If range request was not successful (e.g., 416 Range Not Satisfiable),
211-
// don't record in RangeSuccess, which will cause WriteBlob to start fresh
212-
// (no explicit action needed in the else case)
213212
}
214213

215214
return resp, nil
216215
}
217216

217+
// rangeStartMatchesOffset parses the Content-Range response header and reports
218+
// whether its start byte equals the given offset. The format is defined by
219+
// RFC 9110: "bytes START-END/TOTAL" (TOTAL may be "*"). We fail closed: if the
220+
// header is absent or cannot be parsed we return false so that the caller does
221+
// not treat an ambiguous response as a successful range request.
222+
func rangeStartMatchesOffset(contentRange string, offset int64) bool {
223+
if contentRange == "" {
224+
return false
225+
}
226+
// Trim the unit prefix "bytes " and split on "-"
227+
after, ok := strings.CutPrefix(contentRange, "bytes ")
228+
if !ok {
229+
return false
230+
}
231+
dashIdx := strings.Index(after, "-")
232+
if dashIdx < 0 {
233+
return false
234+
}
235+
var start int64
236+
if _, err := fmt.Sscanf(after[:dashIdx], "%d", &start); err != nil {
237+
return false
238+
}
239+
return start == offset
240+
}
241+
218242
// extractDigestAndOffset extracts the blob digest from the request URL and returns
219243
// the corresponding resume offset if one exists.
220244
func (t *rangeTransport) extractDigestAndOffset(req *http.Request, offsets map[string]int64) (string, int64) {

0 commit comments

Comments
 (0)