From f6d0201f1e243827a50dbb4856d488285e473423 Mon Sep 17 00:00:00 2001 From: Soule BA Date: Mon, 26 Feb 2024 15:42:32 +0100 Subject: [PATCH 1/2] Enable pulling large files in parallel This is an attempt to better meet the expectations of users who pull large files. If implemented, this will allow chunks of a given artifact layer to be pulled concurrently. Signed-off-by: Soule BA --- oci/client/client.go | 4 +- oci/client/pull.go | 223 +++++++++++++++++++++++++++++++++-- oci/client/pull_test.go | 21 ++++ oci/client/push_pull_test.go | 1 + oci/go.mod | 4 +- oci/go.sum | 6 + 6 files changed, 248 insertions(+), 11 deletions(-) diff --git a/oci/client/client.go b/oci/client/client.go index b3cd257a0..c816855a7 100644 --- a/oci/client/client.go +++ b/oci/client/client.go @@ -21,13 +21,15 @@ import ( "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/v1/remote" +"github.com/hashicorp/go-retryablehttp" "github.com/fluxcd/pkg/oci" ) // Client holds the options for accessing remote OCI registries. type Client struct { - options []crane.Option + options []crane.Option + httpClient *retryablehttp.Client } // NewClient returns an OCI client configured with the given crane options. 
diff --git a/oci/client/pull.go b/oci/client/pull.go index 633bf4b30..b49892a7f 100644 --- a/oci/client/pull.go +++ b/oci/client/pull.go @@ -22,13 +22,28 @@ import ( "context" "fmt" "io" + "net/http" + "net/url" "os" + "github.com/fluxcd/pkg/tar" + "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/name" - gcrv1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/hashicorp/go-retryablehttp" - "github.com/fluxcd/pkg/tar" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "golang.org/x/sync/errgroup" +) + +const ( + // thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go. + // If the layer is larger than this, it will be downloaded in chunks. + thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB + // maxConcurrentPulls is the maximum number of concurrent downloads. + maxConcurrentPulls = 10 ) var ( @@ -39,8 +54,12 @@ var ( // PullOptions contains options for pulling a layer. type PullOptions struct { - layerIndex int - layerType LayerType + layerIndex int + layerType LayerType + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain + concurrency int } // PullOption is a function for configuring PullOptions. @@ -60,22 +79,53 @@ func WithPullLayerIndex(i int) PullOption { } } +func WithTransport(t http.RoundTripper) PullOption { + return func(o *PullOptions) { + o.transport = t + } +} + +func WithConcurrency(c int) PullOption { + return func(o *PullOptions) { + o.concurrency = c + } +} + // Pull downloads an artifact from an OCI repository and extracts the content. // It untar or copies the content to the given outPath depending on the layerType. // If no layer type is given, it tries to determine the right type by checking compressed content of the layer. 
-func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOption) (*Metadata, error) { +func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...PullOption) (*Metadata, error) { o := &PullOptions{ layerIndex: 0, } + o.keychain = authn.DefaultKeychain for _, opt := range opts { opt(o) } - ref, err := name.ParseReference(url) + + if o.concurrency == 0 || o.concurrency > maxConcurrentPulls { + o.concurrency = maxConcurrentPulls + } + + if o.transport == nil { + transport := remote.DefaultTransport.(*http.Transport).Clone() + o.transport = transport + } + + ref, err := name.ParseReference(urlString) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } - img, err := crane.Pull(url, c.optionsWithContext(ctx)...) + if c.httpClient == nil { + h, err := makeHttpClient(ctx, ref.Context(), *o) + if err != nil { + return nil, err + } + c.httpClient = h + } + + img, err := crane.Pull(urlString, c.optionsWithContext(ctx)...) if err != nil { return nil, err } @@ -91,7 +141,7 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti } meta := MetadataFromAnnotations(manifest.Annotations) - meta.URL = url + meta.URL = urlString meta.Digest = ref.Context().Digest(digest.String()).String() layers, err := img.Layers() @@ -107,6 +157,34 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti return nil, fmt.Errorf("index '%d' out of bound for '%d' layers in artifact", o.layerIndex, len(layers)) } + size, err := layers[o.layerIndex].Size() + if err != nil { + return nil, fmt.Errorf("failed to get layer size: %w", err) + } + + if size > thresholdForConcurrentPull { + digest, err := layers[o.layerIndex].Digest() + if err != nil { + return nil, fmt.Errorf("parsing digest failed: %w", err) + } + u := url.URL{ + Scheme: ref.Context().Scheme(), + Host: ref.Context().RegistryStr(), + Path: fmt.Sprintf("/v2/%s/blobs/%s", ref.Context().RepositoryStr(), digest.String()), + } + ok, err := 
c.IsRangeRequestEnabled(ctx, u) + if err != nil { + return nil, fmt.Errorf("failed to check range request support: %w", err) + } + if ok { + err = c.concurrentExtractLayer(ctx, u, layers[o.layerIndex], outPath, digest, size, o.concurrency) + if err != nil { + return nil, err + } + return meta, nil + } + } + err = extractLayer(layers[o.layerIndex], outPath, o.layerType) if err != nil { return nil, err @@ -114,8 +192,98 @@ func (c *Client) Pull(ctx context.Context, url, outPath string, opts ...PullOpti return meta, nil } +// TO DO: handle authentication handle using keychain for authentication +func (c *Client) IsRangeRequestEnabled(ctx context.Context, u url.URL) (bool, error) { + req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil) + if err != nil { + return false, err + } + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return false, err + } + + if err := transport.CheckError(resp, http.StatusOK); err != nil { + return false, err + } + + if rangeUnit := resp.Header.Get("Accept-Ranges"); rangeUnit == "bytes" { + return true, nil + } + for k, v := range resp.Header { + fmt.Printf("Header: %s, Value: %s\n", k, v) + } + return false, nil +} + +func (c *Client) concurrentExtractLayer(ctx context.Context, u url.URL, layer v1.Layer, path string, digest v1.Hash, size int64, concurrency int) error { + chunkSize := size / int64(concurrency) + chunks := make([][]byte, concurrency+1) + diff := size % int64(concurrency) + + g, ctx := errgroup.WithContext(ctx) + for i := 0; i < concurrency; i++ { + i := i + g.Go(func() (err error) { + start, end := int64(i)*chunkSize, int64(i+1)*chunkSize + if i == concurrency-1 { + end += diff + } + req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return fmt.Errorf("failed to create a new request: %w", err) + } + req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end-1)) + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return 
fmt.Errorf("failed to download archive: %w", err) + } + defer resp.Body.Close() + + if err := transport.CheckError(resp, http.StatusPartialContent); err != nil { + return fmt.Errorf("failed to download archive from %s (status: %s)", u.String(), resp.Status) + } + + c, err := io.ReadAll(io.LimitReader(resp.Body, end-start)) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + chunks[i] = c + return nil + }) + } + err := g.Wait() + if err != nil { + return err + } + + content := bufio.NewReader(bytes.NewReader(bytes.Join(chunks, nil))) + d, s, err := v1.SHA256(content) + if err != nil { + return err + } + if d != digest { + return fmt.Errorf("digest mismatch: expected %s, got %s", digest, d) + } + if s != size { + return fmt.Errorf("size mismatch: expected %d, got %d", size, size) + } + + f, err := os.Create(path) + if err != nil { + return err + } + + _, err = io.Copy(f, content) + if err != nil { + return fmt.Errorf("error copying layer content: %s", err) + } + return nil +} + // extractLayer extracts the Layer to the path -func extractLayer(layer gcrv1.Layer, path string, layerType LayerType) error { +func extractLayer(layer v1.Layer, path string, layerType LayerType) error { var blob io.Reader blob, err := layer.Compressed() if err != nil { @@ -173,3 +341,40 @@ func isGzipBlob(buf *bufio.Reader) (bool, error) { } return bytes.Equal(b, gzipMagicHeader), nil } + +type resource interface { + Scheme() string + RegistryStr() string + Scope(string) string + + authn.Resource +} + +func makeHttpClient(ctx context.Context, target resource, o PullOptions) (*retryablehttp.Client, error) { + auth := o.auth + if o.keychain != nil { + kauth, err := o.keychain.Resolve(target) + if err != nil { + return nil, err + } + auth = kauth + } + + reg, ok := target.(name.Registry) + if !ok { + repo, ok := target.(name.Repository) + if !ok { + return nil, fmt.Errorf("unexpected resource: %T", target) + } + reg = repo.Registry + } + + tr, err := 
transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)}) + if err != nil { + return nil, err + } + + h := retryablehttp.NewClient() + h.HTTPClient = &http.Client{Transport: tr} + return h, nil +} diff --git a/oci/client/pull_test.go b/oci/client/pull_test.go index 86795284d..b68dd15a5 100644 --- a/oci/client/pull_test.go +++ b/oci/client/pull_test.go @@ -41,6 +41,7 @@ func Test_PullAnyTarball(t *testing.T) { repo := "test-no-annotations" + randStringRunes(5) dst := fmt.Sprintf("%s/%s:%s", dockerReg, repo, tag) + fmt.Println("Pulling from:", dst) artifact := filepath.Join(t.TempDir(), "artifact.tgz") g.Expect(build(artifact, testDir, nil)).To(Succeed()) @@ -82,3 +83,23 @@ func Test_PullAnyTarball(t *testing.T) { g.Expect(extractTo + "/" + entry).To(Or(BeAnExistingFile(), BeADirectory())) } } + +func Test_PullLargeTarball(t *testing.T) { + g := NewWithT(t) + ctx := context.Background() + c := NewClient(DefaultOptions()) + dst := "vnp505/zephyr-7b-alpha:alpha" + extractTo := filepath.Join(t.TempDir(), "artifact") + m, err := c.Pull(ctx, dst, extractTo, WithPullLayerIndex(19)) + fmt.Println("Pulled from:", dst) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(m).ToNot(BeNil()) + g.Expect(m.Annotations).To(BeEmpty()) + g.Expect(m.Created).To(BeEmpty()) + g.Expect(m.Revision).To(BeEmpty()) + g.Expect(m.Source).To(BeEmpty()) + g.Expect(m.URL).To(Equal(dst)) + g.Expect(m.Digest).ToNot(BeEmpty()) + g.Expect(extractTo).ToNot(BeEmpty()) +} diff --git a/oci/client/push_pull_test.go b/oci/client/push_pull_test.go index 3c68b2537..9d02f1015 100644 --- a/oci/client/push_pull_test.go +++ b/oci/client/push_pull_test.go @@ -305,6 +305,7 @@ func Test_Push_Pull(t *testing.T) { g.Expect(err).ToNot(HaveOccurred()) fileInfo, err := os.Stat(tt.sourcePath) + g.Expect(err).ToNot(HaveOccurred()) // if a directory was pushed, then the created file should be a gzipped archive if fileInfo.IsDir() { bufReader := 
bufio.NewReader(bytes.NewReader(got)) diff --git a/oci/go.mod b/oci/go.mod index e992b5755..50681ead3 100644 --- a/oci/go.mod +++ b/oci/go.mod @@ -21,9 +21,11 @@ require ( github.com/fluxcd/pkg/tar v0.4.0 github.com/fluxcd/pkg/version v0.2.2 github.com/google/go-containerregistry v0.18.0 + github.com/hashicorp/go-retryablehttp v0.7.5 github.com/onsi/gomega v1.31.1 github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/sirupsen/logrus v1.9.3 + golang.org/x/sync v0.6.0 sigs.k8s.io/controller-runtime v0.16.3 ) @@ -80,6 +82,7 @@ require ( github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/golang-lru/arc/v2 v2.0.5 // indirect github.com/hashicorp/golang-lru/v2 v2.0.5 // indirect github.com/imdario/mergo v0.3.15 // indirect @@ -130,7 +133,6 @@ require ( golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect - golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.16.0 // indirect golang.org/x/term v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/oci/go.sum b/oci/go.sum index 87ecfa5d7..45aeee99f 100644 --- a/oci/go.sum +++ b/oci/go.sum @@ -155,6 +155,12 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.9.2 
h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= +github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= +github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= +github.com/hashicorp/go-retryablehttp v0.7.5/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= github.com/hashicorp/golang-lru/arc/v2 v2.0.5 h1:l2zaLDubNhW4XO3LnliVj0GXO3+/CGNJAg1dcN2Fpfw= github.com/hashicorp/golang-lru/arc/v2 v2.0.5/go.mod h1:ny6zBSQZi2JxIeYcv7kt2sH2PXJtirBN7RDhRpxPkxU= github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4= From 8fc8f0b3e3b5c4541d3eb3602b3f583719d4c12d Mon Sep 17 00:00:00 2001 From: Soule BA Date: Mon, 4 Mar 2024 17:46:23 +0100 Subject: [PATCH 2/2] introduce a blob manager which handles chunk downloads Signed-off-by: Soule BA --- oci/client/client.go | 4 +- oci/client/download.go | 359 +++++++++++++++++++++++++++++++++++++++++ oci/client/pull.go | 212 ++++-------------------- 3 files changed, 392 insertions(+), 183 deletions(-) create mode 100644 oci/client/download.go diff --git a/oci/client/client.go b/oci/client/client.go index c816855a7..b3cd257a0 100644 --- a/oci/client/client.go +++ b/oci/client/client.go @@ -21,15 +21,13 @@ import ( "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/v1/remote" -"github.com/hashicorp/go-retryablehttp" "github.com/fluxcd/pkg/oci" ) // Client holds the options for accessing remote OCI registries. type Client struct { - options []crane.Option - httpClient *retryablehttp.Client + options []crane.Option } // NewClient returns an OCI client configured with the given crane options. 
diff --git a/oci/client/download.go b/oci/client/download.go new file mode 100644 index 000000000..6f882f684 --- /dev/null +++ b/oci/client/download.go @@ -0,0 +1,359 @@ +/* +Copyright 2024 The Flux authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "syscall" + "time" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "github.com/hashicorp/go-retryablehttp" + "golang.org/x/sync/errgroup" +) + +const ( + minChunkSize = 100 * 1024 * 1024 // 100MB + maxChunkSize = 1 << 30 // 1GB + defaultNumberOfChunks = 50 +) + +var ( + // errRangeRequestNotSupported is returned when the registry does not support range requests. 
+ errRangeRequestNotSupported = fmt.Errorf("range requests are not supported by the registry") + errCopyFailed = errors.New("copy failed") +) + +var ( + retries = 3 + defaultRetryBackoff = remote.Backoff{ + Duration: 1.0 * time.Second, + Factor: 3.0, + Jitter: 0.1, + Steps: retries, + } +) + +type downloadOption func(*downloadOptions) + +type downloadOptions struct { + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain + numberOfChunks int +} + +type blobManager struct { + name name.Reference + c *retryablehttp.Client + layer v1.Layer + path string + digest v1.Hash + size int64 + downloadOptions +} + +func withTransport(t http.RoundTripper) downloadOption { + return func(o *downloadOptions) { + o.transport = t + } +} + +func withAuth(auth authn.Authenticator) downloadOption { + return func(o *downloadOptions) { + o.auth = auth + } +} + +func withKeychain(k authn.Keychain) downloadOption { + return func(o *downloadOptions) { + o.keychain = k + } +} + +func withNumberOfChunks(n int) downloadOption { + return func(o *downloadOptions) { + o.numberOfChunks = n + } +} + +type chunk struct { + n int + offset int64 + size int64 + writeCounter +} + +func makeChunk(n int, offset, size int64) *chunk { + return &chunk{ + n: n, + offset: offset, + size: size, + writeCounter: writeCounter{}, + } +} + +// newDownloader returns a new blobManager with the given options. 
+func newDownloader(name name.Reference, path string, layer v1.Layer, opts ...downloadOption) *blobManager { + o := &downloadOptions{ + numberOfChunks: defaultNumberOfChunks, + keychain: authn.DefaultKeychain, + transport: remote.DefaultTransport.(*http.Transport).Clone(), + } + d := &blobManager{ + layer: layer, + name: name, + path: path, + downloadOptions: *o, + } + for _, opt := range opts { + opt(&d.downloadOptions) + } + + return d +} + +func (d *blobManager) download(ctx context.Context) error { + digest, err := d.layer.Digest() + if err != nil { + return fmt.Errorf("failed to get layer digest: %w", err) + } + d.digest = digest + + size, err := d.layer.Size() + if err != nil { + return fmt.Errorf("failed to get layer size: %w", err) + } + d.size = size + + if d.c == nil { + h, err := makeHttpClient(ctx, d.name.Context(), &d.downloadOptions) + if err != nil { + return fmt.Errorf("failed to create HTTP client: %w", err) + } + d.c = h + } + + ok, err := d.isRangeRequestEnabled(ctx) + if err != nil { + return fmt.Errorf("failed to check range request support: %w", err) + } + + if !ok { + return errRangeRequestNotSupported + } + + if err := d.downloadChunks(ctx); err != nil { + return fmt.Errorf("failed to download layer in chunks: %w", err) + } + + if err := d.verifyDigest(); err != nil { + return fmt.Errorf("failed to verify layer digest: %w", err) + } + + return nil +} + +func (d *blobManager) downloadChunks(ctx context.Context) error { + u := makeUrl(d.name, d.digest) + + file, err := os.OpenFile(d.path+".tmp", os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("failed to create layer file: %w", err) + } + defer file.Close() + + chunkSize := d.size / int64(d.numberOfChunks) + if chunkSize < minChunkSize { + chunkSize = minChunkSize + } else if chunkSize > maxChunkSize { + chunkSize = maxChunkSize + } + + var ( + chunks []*chunk + n int + ) + + for offset := int64(0); offset < d.size; offset += chunkSize { + if offset+chunkSize > d.size { + 
chunkSize = d.size - offset + } + chunk := makeChunk(n, offset, chunkSize) + chunks = append(chunks, chunk) + n++ + } + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(d.numberOfChunks) + for _, chunk := range chunks { + chunk := chunk + g.Go(func() error { + b := defaultRetryBackoff + for i := 0; i < retries; i++ { + w := io.NewOffsetWriter(file, chunk.offset) + err := chunk.download(ctx, d.c, w, u) + switch { + case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): + return err + case errors.Is(err, errCopyFailed): + time.Sleep(b.Step()) + continue + default: + return nil + } + } + return fmt.Errorf("failed to download chunk %d: %w", n, err) + }) + } + + err = g.Wait() + if err != nil { + return fmt.Errorf("failed to download layer in chunks: %w", err) + } + + if err := os.Rename(file.Name(), d.path); err != nil { + return err + } + + return nil + +} + +func (c *chunk) download(ctx context.Context, client *retryablehttp.Client, w io.Writer, u url.URL) error { + req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return err + } + + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", c.offset, c.offset+c.size-1)) + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return err + } + + if err := transport.CheckError(resp, http.StatusPartialContent); err != nil { + return err + } + + _, err = io.Copy(w, io.TeeReader(resp.Body, &c.writeCounter)) + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { + // TODO: if the download was interrupted, we can resume it + return fmt.Errorf("failed to download chunk %d: %w", c.n, err) + } + + return err +} + +func (d *blobManager) isRangeRequestEnabled(ctx context.Context) (bool, error) { + u := makeUrl(d.name, d.digest) + req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil) + if err != nil { + return false, err + } + + resp, err := d.c.Do(req.WithContext(ctx)) + if err != nil { + return false, err + } + 
+ if err := transport.CheckError(resp, http.StatusOK); err != nil { + return false, err + } + + if rangeUnit := resp.Header.Get("Accept-Ranges"); rangeUnit == "bytes" { + return true, nil + } + + return false, nil +} + +func (d *blobManager) verifyDigest() error { + f, err := os.Open(d.path) + if err != nil { + return fmt.Errorf("failed to open layer file: %w", err) + } + defer f.Close() + + h := sha256.New() + _, err = io.Copy(h, f) + if err != nil { + return fmt.Errorf("failed to hash layer: %w", err) + } + + newDigest := h.Sum(nil) + if d.digest.String() != fmt.Sprintf("sha256:%x", newDigest) { + return fmt.Errorf("layer digest does not match: %s != sha256:%x", d.digest.String(), newDigest) + } + return nil +} + +func makeUrl(name name.Reference, digest v1.Hash) url.URL { + return url.URL{ + Scheme: name.Context().Scheme(), + Host: name.Context().RegistryStr(), + Path: fmt.Sprintf("/v2/%s/blobs/%s", name.Context().RepositoryStr(), digest.String()), + } +} + +type resource interface { + Scheme() string + RegistryStr() string + Scope(string) string + + authn.Resource +} + +func makeHttpClient(ctx context.Context, target resource, o *downloadOptions) (*retryablehttp.Client, error) { + auth := o.auth + if o.keychain != nil { + kauth, err := o.keychain.Resolve(target) + if err != nil { + return nil, err + } + auth = kauth + } + + reg, ok := target.(name.Registry) + if !ok { + repo, ok := target.(name.Repository) + if !ok { + return nil, fmt.Errorf("unexpected resource: %T", target) + } + reg = repo.Registry + } + + tr, err := transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)}) + if err != nil { + return nil, err + } + + h := retryablehttp.NewClient() + h.HTTPClient = &http.Client{Transport: tr} + return h, nil +} diff --git a/oci/client/pull.go b/oci/client/pull.go index b49892a7f..5ed76515d 100644 --- a/oci/client/pull.go +++ b/oci/client/pull.go @@ -20,31 +20,28 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" 
"io" "net/http" - "net/url" "os" "github.com/fluxcd/pkg/tar" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/name" - "github.com/hashicorp/go-retryablehttp" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/remote" - "github.com/google/go-containerregistry/pkg/v1/remote/transport" - "golang.org/x/sync/errgroup" ) -const ( - // thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go. - // If the layer is larger than this, it will be downloaded in chunks. - thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB - // maxConcurrentPulls is the maximum number of concurrent downloads. - maxConcurrentPulls = 10 -) +// const ( +// // thresholdForConcurrentPull is the maximum size of a layer to be extracted in one go. +// // If the layer is larger than this, it will be downloaded in chunks. +// thresholdForConcurrentPull = 100 * 1024 * 1024 // 100MB +// // maxConcurrentPulls is the maximum number of concurrent downloads. +// maxConcurrentPulls = 10 +// ) var ( // gzipMagicHeader are bytes found at the start of gzip files @@ -54,12 +51,11 @@ var ( // PullOptions contains options for pulling a layer. type PullOptions struct { - layerIndex int - layerType LayerType - transport http.RoundTripper - auth authn.Authenticator - keychain authn.Keychain - concurrency int + layerIndex int + layerType LayerType + transport http.RoundTripper + auth authn.Authenticator + keychain authn.Keychain } // PullOption is a function for configuring PullOptions. 
@@ -85,9 +81,15 @@ func WithTransport(t http.RoundTripper) PullOption { } } -func WithConcurrency(c int) PullOption { +func WithAuth(auth authn.Authenticator) PullOption { + return func(o *PullOptions) { + o.auth = auth + } +} + +func WithKeychain(k authn.Keychain) PullOption { return func(o *PullOptions) { - o.concurrency = c + o.keychain = k } } @@ -103,10 +105,6 @@ func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...Pu opt(o) } - if o.concurrency == 0 || o.concurrency > maxConcurrentPulls { - o.concurrency = maxConcurrentPulls - } - if o.transport == nil { transport := remote.DefaultTransport.(*http.Transport).Clone() o.transport = transport @@ -117,14 +115,6 @@ func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...Pu return nil, fmt.Errorf("invalid URL: %w", err) } - if c.httpClient == nil { - h, err := makeHttpClient(ctx, ref.Context(), *o) - if err != nil { - return nil, err - } - c.httpClient = h - } - img, err := crane.Pull(urlString, c.optionsWithContext(ctx)...) 
if err != nil { return nil, err @@ -162,126 +152,25 @@ func (c *Client) Pull(ctx context.Context, urlString, outPath string, opts ...Pu return nil, fmt.Errorf("failed to get layer size: %w", err) } - if size > thresholdForConcurrentPull { - digest, err := layers[o.layerIndex].Digest() - if err != nil { - return nil, fmt.Errorf("parsing digest failed: %w", err) - } - u := url.URL{ - Scheme: ref.Context().Scheme(), - Host: ref.Context().RegistryStr(), - Path: fmt.Sprintf("/v2/%s/blobs/%s", ref.Context().RepositoryStr(), digest.String()), + if size > minChunkSize { + manager := newDownloader(ref, outPath, layers[o.layerIndex], + withTransport(o.transport), withKeychain(o.keychain), withAuth(o.auth)) + err = manager.download(ctx) + if err != nil && !errors.Is(err, errRangeRequestNotSupported) { + return nil, fmt.Errorf("failed to download layer: %w", err) } - ok, err := c.IsRangeRequestEnabled(ctx, u) + } + + if size <= minChunkSize || errors.Is(err, errRangeRequestNotSupported) { + err = extractLayer(layers[o.layerIndex], outPath, o.layerType) if err != nil { - return nil, fmt.Errorf("failed to check range request support: %w", err) - } - if ok { - err = c.concurrentExtractLayer(ctx, u, layers[o.layerIndex], outPath, digest, size, o.concurrency) - if err != nil { - return nil, err - } - return meta, nil + return nil, err } } - err = extractLayer(layers[o.layerIndex], outPath, o.layerType) - if err != nil { - return nil, err - } return meta, nil } -// TO DO: handle authentication handle using keychain for authentication -func (c *Client) IsRangeRequestEnabled(ctx context.Context, u url.URL) (bool, error) { - req, err := retryablehttp.NewRequest(http.MethodHead, u.String(), nil) - if err != nil { - return false, err - } - - resp, err := c.httpClient.Do(req.WithContext(ctx)) - if err != nil { - return false, err - } - - if err := transport.CheckError(resp, http.StatusOK); err != nil { - return false, err - } - - if rangeUnit := resp.Header.Get("Accept-Ranges"); rangeUnit 
== "bytes" { - return true, nil - } - for k, v := range resp.Header { - fmt.Printf("Header: %s, Value: %s\n", k, v) - } - return false, nil -} - -func (c *Client) concurrentExtractLayer(ctx context.Context, u url.URL, layer v1.Layer, path string, digest v1.Hash, size int64, concurrency int) error { - chunkSize := size / int64(concurrency) - chunks := make([][]byte, concurrency+1) - diff := size % int64(concurrency) - - g, ctx := errgroup.WithContext(ctx) - for i := 0; i < concurrency; i++ { - i := i - g.Go(func() (err error) { - start, end := int64(i)*chunkSize, int64(i+1)*chunkSize - if i == concurrency-1 { - end += diff - } - req, err := retryablehttp.NewRequest(http.MethodGet, u.String(), nil) - if err != nil { - return fmt.Errorf("failed to create a new request: %w", err) - } - req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", start, end-1)) - resp, err := c.httpClient.Do(req.WithContext(ctx)) - if err != nil { - return fmt.Errorf("failed to download archive: %w", err) - } - defer resp.Body.Close() - - if err := transport.CheckError(resp, http.StatusPartialContent); err != nil { - return fmt.Errorf("failed to download archive from %s (status: %s)", u.String(), resp.Status) - } - - c, err := io.ReadAll(io.LimitReader(resp.Body, end-start)) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - chunks[i] = c - return nil - }) - } - err := g.Wait() - if err != nil { - return err - } - - content := bufio.NewReader(bytes.NewReader(bytes.Join(chunks, nil))) - d, s, err := v1.SHA256(content) - if err != nil { - return err - } - if d != digest { - return fmt.Errorf("digest mismatch: expected %s, got %s", digest, d) - } - if s != size { - return fmt.Errorf("size mismatch: expected %d, got %d", size, size) - } - - f, err := os.Create(path) - if err != nil { - return err - } - - _, err = io.Copy(f, content) - if err != nil { - return fmt.Errorf("error copying layer content: %s", err) - } - return nil -} - // extractLayer extracts the Layer 
to the path func extractLayer(layer v1.Layer, path string, layerType LayerType) error { var blob io.Reader @@ -341,40 +230,3 @@ func isGzipBlob(buf *bufio.Reader) (bool, error) { } return bytes.Equal(b, gzipMagicHeader), nil } - -type resource interface { - Scheme() string - RegistryStr() string - Scope(string) string - - authn.Resource -} - -func makeHttpClient(ctx context.Context, target resource, o PullOptions) (*retryablehttp.Client, error) { - auth := o.auth - if o.keychain != nil { - kauth, err := o.keychain.Resolve(target) - if err != nil { - return nil, err - } - auth = kauth - } - - reg, ok := target.(name.Registry) - if !ok { - repo, ok := target.(name.Repository) - if !ok { - return nil, fmt.Errorf("unexpected resource: %T", target) - } - reg = repo.Registry - } - - tr, err := transport.NewWithContext(ctx, reg, auth, o.transport, []string{target.Scope(transport.PullScope)}) - if err != nil { - return nil, err - } - - h := retryablehttp.NewClient() - h.HTTPClient = &http.Client{Transport: tr} - return h, nil -}