// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transfermanager

import (
	"bytes"
	"context"
	"errors"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"

	"cloud.google.com/go/storage"
)

// TestWaitAndClose verifies that a Downloader rejects new work once
// WaitAndClose has returned: DownloadObject must fail with an error that
// mentions the use-after-close condition.
func TestWaitAndClose(t *testing.T) {
	t.Parallel()

	// Substring the post-close error is required to contain.
	const expectedErr = "transfermanager: Downloader used after WaitAndClose was called"

	downloader, newErr := NewDownloader(nil)
	if newErr != nil {
		t.Fatalf("NewDownloader: %v", newErr)
	}

	_, closeErr := downloader.WaitAndClose()
	if closeErr != nil {
		t.Fatalf("WaitAndClose: %v", closeErr)
	}

	// Any use after closing must be refused with the expected message.
	switch useErr := downloader.DownloadObject(context.Background(), &DownloadObjectInput{}); {
	case useErr == nil:
		t.Fatalf("d.DownloadObject err was nil, should be %q", expectedErr)
	case !strings.Contains(useErr.Error(), expectedErr):
		t.Errorf("expected err %q, got: %v", expectedErr, useErr.Error())
	}
}

// TestNumShards is a table-driven test of numShards. Each case supplies the
// object size, the part size and an optional requested range (plus a flag that
// marks the object as gzip-encoded) and states how many shards numShards is
// expected to report.
//
// Every case has a unique description so that a failing subtest can be
// identified and selected with -run unambiguously.
func TestNumShards(t *testing.T) {
	t.Parallel()
	for _, test := range []struct {
		desc       string
		objRange   *DownloadRange // requested range; nil when no range is requested
		objSize    int64          // reported through ReaderObjectAttrs.Size
		partSize   int64          // part size handed to numShards
		transcoded bool           // when true, ContentEncoding is set to "gzip"
		want       int            // expected number of shards
	}{
		{
			desc:     "nil range",
			objSize:  100,
			partSize: 1000,
			want:     1,
		},
		{
			desc:     "nil - object equal to partSize",
			objSize:  100,
			partSize: 100,
			want:     1,
		},
		{
			desc:     "nil - object larger than partSize",
			objSize:  100,
			partSize: 10,
			want:     10,
		},
		{
			desc: "full object smaller than partSize",
			objRange: &DownloadRange{
				Length: 100,
			},
			objSize:  100,
			partSize: 101,
			want:     1,
		},
		{
			desc: "full object equal to partSize",
			objRange: &DownloadRange{
				Length: 100,
			},
			objSize:  100,
			partSize: 100,
			want:     1,
		},
		{
			desc: "full object larger than partSize",
			objRange: &DownloadRange{
				Length: 100,
			},
			objSize:  100,
			partSize: 99,
			want:     2,
		},
		{
			desc: "partial object smaller than partSize",
			objRange: &DownloadRange{
				Length: 50,
			},
			objSize:  100,
			partSize: 1000,
			want:     1,
		},
		{
			// The range stops one byte short of the object's end, so this is
			// a partial download even though it spans several parts.
			desc: "partial object larger than partSize",
			objRange: &DownloadRange{
				Length: 5000,
			},
			objSize:  5001,
			partSize: 1000,
			want:     5,
		},
		{
			desc: "full object larger than partSize - off by one check",
			objRange: &DownloadRange{
				Length: 5001,
			},
			objSize:  5001,
			partSize: 1000,
			want:     6,
		},
		{
			desc: "length larger than object size",
			objRange: &DownloadRange{
				Length: 17000,
			},
			objSize:  5000,
			partSize: 1000,
			want:     5,
		},
		{
			desc: "negative length",
			objRange: &DownloadRange{
				Length: -1,
			},
			objSize:  5000,
			partSize: 1000,
			want:     5,
		},
		{
			desc: "offset object smaller than partSize",
			objRange: &DownloadRange{
				Offset: 50,
				Length: 99,
			},
			objSize:  100,
			partSize: 1000,
			want:     1,
		},
		{
			desc: "offset object larger than partSize",
			objRange: &DownloadRange{
				Offset: 1000,
				Length: 1999,
			},
			objSize:  2000,
			partSize: 100,
			want:     10,
		},
		{
			desc: "offset object larger than partSize - length larger than objSize",
			objRange: &DownloadRange{
				Offset: 1000,
				Length: 10000,
			},
			objSize:  2001,
			partSize: 100,
			want:     11,
		},
		{
			desc: "negative offset smaller than partSize",
			objRange: &DownloadRange{
				Offset: -5,
				Length: -1,
			},
			objSize:  1024 * 1024 * 1024 * 10,
			partSize: 100,
			want:     1,
		},
		{
			desc: "negative offset larger than partSize",
			objRange: &DownloadRange{
				Offset: -1000,
				Length: -1,
			},
			objSize:  2000,
			partSize: 100,
			want:     1,
		},
		{
			desc:       "transcoded",
			objSize:    2000,
			partSize:   100,
			transcoded: true,
			want:       1,
		},
	} {
		t.Run(test.desc, func(t *testing.T) {
			attrs := &storage.ReaderObjectAttrs{
				Size: test.objSize,
			}

			// The transcoded case is modelled by marking the object as
			// gzip-encoded.
			if test.transcoded {
				attrs.ContentEncoding = "gzip"
			}

			got := numShards(attrs, test.objRange, test.partSize)

			if got != test.want {
				t.Errorf("numShards incorrect; expect object to be divided into %d shards, got %d", test.want, got)
			}
		})
	}
}

// TestIsSubPath is a table-driven test of isSubPath. Each case supplies a
// local directory and a file path and states either the expected boolean
// result or a substring the returned error must contain.
//
// The test is deliberately NOT marked parallel: the error case changes the
// process-wide working directory (and deletes it), which would be observed by
// any test running concurrently in the same process.
func TestIsSubPath(t *testing.T) {
	sep := string(filepath.Separator)

	// Temporary directory used as the base for the cases built with
	// filepath.Join below.
	tempDir := t.TempDir()

	testCases := []struct {
		name           string
		localDirectory string
		filePath       string
		wantIsSub      bool
		wantErrMsg     string // non-empty means an error containing this text is expected
	}{
		{
			name:           "filePath is a child",
			localDirectory: "/tmp/foo",
			filePath:       "/tmp/foo/bar",
			wantIsSub:      true,
		},
		{
			name:           "filePath is a nested child",
			localDirectory: "/tmp/foo",
			filePath:       "/tmp/foo/bar/baz",
			wantIsSub:      true,
		},
		{
			name:           "filePath is the same as localDirectory",
			localDirectory: "/tmp/foo",
			filePath:       "/tmp/foo",
			wantIsSub:      true,
		},
		{
			name:           "filePath is a sibling",
			localDirectory: "/tmp/foo",
			filePath:       "/tmp/bar",
			wantIsSub:      false,
		},
		{
			name:           "filePath is a parent",
			localDirectory: "/tmp/foo",
			filePath:       "/tmp",
			wantIsSub:      false,
		},
		{
			name:           "directory traversal attempt",
			localDirectory: "/tmp/foo",
			filePath:       "/tmp/foo/../bar", // resolves to /tmp/bar
			wantIsSub:      false,
		},
		{
			name:           "deeper directory traversal attempt",
			localDirectory: "/tmp/foo",
			filePath:       "/tmp/foo/bar/../../baz", // resolves to /tmp/baz
			wantIsSub:      false,
		},
		{
			name:           "relative paths - valid",
			localDirectory: tempDir,
			filePath:       filepath.Join(tempDir, "bar"),
			wantIsSub:      true,
		},
		{
			name:           "relative paths - traversal",
			localDirectory: tempDir,
			filePath:       filepath.Join(tempDir, "..", "bar"),
			wantIsSub:      false,
		},
		{
			name:           "relative path is just ..",
			localDirectory: "foo",
			filePath:       "foo" + sep + ".." + sep + "..",
			wantIsSub:      false,
		},
		{
			name:           "IsSubPath returns error when base dir is changed",
			localDirectory: "foo",
			filePath:       "bar",
			wantErrMsg:     "no such file or directory",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			wantErr := tc.wantErrMsg != ""
			if wantErr {
				// Induce a filepath.Abs() error by making the current working
				// directory one that no longer exists. The original directory
				// is restored when the subtest finishes.
				origWd, err := os.Getwd()
				if err != nil {
					t.Fatalf("os.Getwd: %v", err)
				}
				t.Cleanup(func() {
					if err := os.Chdir(origWd); err != nil {
						t.Errorf("restoring working directory %q: %v", origWd, err)
					}
				})

				dir, err := os.MkdirTemp("", "")
				if err != nil {
					t.Fatalf("os.MkdirTemp: %v", err)
				}
				if err := os.Chdir(dir); err != nil {
					t.Fatalf("os.Chdir(%q): %v", dir, err)
				}
				if err := os.RemoveAll(dir); err != nil {
					t.Fatalf("os.RemoveAll(%q): %v", dir, err)
				}
			}

			isSub, err := isSubPath(tc.localDirectory, tc.filePath)

			if (err != nil) != wantErr {
				t.Fatalf("isSubPath() error = %v, wantErr %v", err, tc.wantErrMsg)
			}
			if wantErr && !strings.Contains(err.Error(), tc.wantErrMsg) {
				t.Errorf("isSubPath() error = %s, want err containing %s", err.Error(), tc.wantErrMsg)
				return
			}
			if isSub != tc.wantIsSub {
				t.Errorf("isSubPath() = %v, want %v", isSub, tc.wantIsSub)
			}
		})
	}
}

// TestCalculateRange is a table-driven test of shardRange. Each case supplies
// an optional requested range, a part size and a shard index, and states the
// DownloadRange that shardRange is expected to return for that shard.
//
// Every case has a unique description so that a failing subtest can be
// identified and selected with -run unambiguously.
func TestCalculateRange(t *testing.T) {
	t.Parallel()
	for _, test := range []struct {
		desc     string
		objRange *DownloadRange // requested range; nil when no range is requested
		partSize int64          // part size handed to shardRange
		shard    int            // zero-based shard index
		want     DownloadRange  // expected range for that shard
	}{
		{
			desc:     "nil range - first shard",
			partSize: 1000,
			shard:    0,
			want: DownloadRange{
				Length: 1000,
			},
		},
		{
			desc:     "nil range",
			partSize: 1001,
			shard:    3,
			want: DownloadRange{
				Offset: 3003,
				Length: 1001,
			},
		},
		{
			desc: "first shard length smaller than partSize",
			objRange: &DownloadRange{
				Length: 99,
			},
			partSize: 1000,
			shard:    0,
			want: DownloadRange{
				Length: 99,
			},
		},
		{
			desc: "second shard",
			objRange: &DownloadRange{
				Length: 4999,
			},
			partSize: 1000,
			shard:    1,
			want: DownloadRange{
				Offset: 1000,
				Length: 1000,
			},
		},
		{
			desc: "last shard - length is a multiple of partSize",
			objRange: &DownloadRange{
				Length: 5000,
			},
			partSize: 1000,
			shard:    4,
			want: DownloadRange{
				Offset: 4000,
				Length: 1000,
			},
		},
		{
			desc: "last shard - one byte remainder",
			objRange: &DownloadRange{
				Length: 5001,
			},
			partSize: 1000,
			shard:    5,
			want: DownloadRange{
				Offset: 5000,
				Length: 1,
			},
		},
		{
			desc: "single shard with offset",
			objRange: &DownloadRange{
				Offset: 10,
				Length: 99,
			},
			partSize: 1000,
			shard:    0,
			want: DownloadRange{
				Offset: 10,
				Length: 99,
			},
		},
		{
			desc: "second shard with offset",
			objRange: &DownloadRange{
				Offset: 100,
				Length: 500,
			},
			partSize: 100,
			shard:    1,
			want: DownloadRange{
				Offset: 200,
				Length: 100,
			},
		},
		{
			desc: "off by one",
			objRange: &DownloadRange{
				Offset: 101,
				Length: 500,
			},
			partSize: 100,
			shard:    2,
			want: DownloadRange{
				Offset: 301,
				Length: 100,
			},
		},
		{
			desc: "last shard with offset",
			objRange: &DownloadRange{
				Offset: 1,
				Length: 5000,
			},
			partSize: 1000,
			shard:    4,
			want: DownloadRange{
				Offset: 4001,
				Length: 1000,
			},
		},
		{
			// A partSize of 0 is expected to hand the requested range back
			// unchanged.
			desc: "sharding turned off",
			objRange: &DownloadRange{
				Offset: 1024 * 1024 * 1024 * 1024 / 2,
				Length: 1024 * 1024 * 1024 * 1024,
			},
			partSize: 0,
			shard:    0,
			want: DownloadRange{
				Offset: 1024 * 1024 * 1024 * 1024 / 2,
				Length: 1024 * 1024 * 1024 * 1024,
			},
		},
		{
			desc: "large object",
			objRange: &DownloadRange{
				Offset: 1024 * 1024 * 1024 * 1024 / 2,
				Length: 1024 * 1024 * 1024 * 1024, // 1 TiB
			},
			partSize: 1024 * 1024 * 1024, // 1 GiB
			// Shard index 511: the last shard that starts below the 1 TiB
			// mark once the 0.5 TiB offset is applied.
			// NOTE(review): originally labelled "last shard", but the
			// requested length alone spans 1024 parts — confirm intent.
			shard: 1024/2 - 1,
			want: DownloadRange{
				Offset: 1024*1024*1024*1024 - 1024*1024*1024,
				Length: 1024 * 1024 * 1024,
			},
		},
	} {
		t.Run(test.desc, func(t *testing.T) {
			got := shardRange(test.objRange, test.partSize, test.shard)

			if got != test.want {
				t.Errorf("want %v got %v", test.want, got)
			}
		})
	}
}

// TestGatherShards tests that gatherShards works as expected and cancels other
// shards without error after it encounters an error.
//
// It runs gatherShards twice against the same input:
//  1. with three error-free shard outputs, expecting the callback to see no
//     error and the originally requested range and object name;
//  2. with one successful, one failed and one cancelled shard, expecting the
//     callback (and later WaitAndClose) to report an error that wraps the
//     shard error but not the internal errCancelAllShards sentinel, and the
//     input context to be cancelled with that sentinel as its cause.
func TestGatherShards(t *testing.T) {
	ctx, cancelCtx := context.WithCancelCause(context.Background())
	// Release the context when the test ends, whether or not the code under
	// test cancelled it. Cancelling an already-cancelled context is a no-op.
	defer cancelCtx(nil)

	// Start a downloader.
	d, err := NewDownloader(nil, WithWorkers(2), WithCallbacks())
	if err != nil {
		t.Fatalf("NewDownloader: %v", err)
	}

	// Test that gatherShards finishes without error.
	object := "obj1"
	shards := 4
	downloadRange := &DownloadRange{
		Offset: 20,
		Length: 120,
	}
	// Shard 0 is passed to gatherShards directly; the rest arrive on outChan.
	firstOut := &DownloadOutput{Object: object, Range: &DownloadRange{Offset: 20, Length: 30}, shard: 0}
	outChan := make(chan *DownloadOutput, shards)
	outs := []*DownloadOutput{
		{Object: object, Range: &DownloadRange{Offset: 50, Length: 30}, shard: 1},
		{Object: object, Range: &DownloadRange{Offset: 80, Length: 30}, shard: 2},
		{Object: object, Range: &DownloadRange{Offset: 110, Length: 30}, shard: 3},
	}

	in := &DownloadObjectInput{
		Callback: func(o *DownloadOutput) {
			if o.Err != nil {
				t.Errorf("unexpected error in DownloadOutput: %v", o.Err)
			}
			// Pointer comparison: the output must carry the very range that
			// was requested on the input.
			if o.Range != downloadRange {
				t.Errorf("mismatching download range, got: %v, want: %v", o.Range, downloadRange)
			}
			if o.Object != object {
				t.Errorf("mismatching object names, got: %v, want: %v", o.Object, object)
			}
		},
		ctx:          ctx,
		cancelCtx:    cancelCtx,
		shardOutputs: outChan,
		Range:        downloadRange,
	}

	var wg sync.WaitGroup
	wg.Add(1)
	d.downloadsInProgress.Add(1)

	go func() {
		d.gatherShards(in, firstOut, outChan, shards, 0)
		wg.Done()
	}()

	for _, o := range outs {
		outChan <- o
	}

	wg.Wait()

	// Test that an error will cancel remaining pieces correctly.
	shardErr := errors.New("some error")

	in.Callback = func(o *DownloadOutput) {
		if o.Err == nil {
			// Report the missing error without dereferencing it below; the
			// callback does not run on the test goroutine, so it must not
			// panic or call Fatal.
			t.Errorf("error in DownloadOutput should wrap error %q; instead got: %v", shardErr, o.Err)
		} else {
			// Error returned should wrap the original error.
			if !errors.Is(o.Err, shardErr) {
				t.Errorf("error in DownloadOutput should wrap error %q; instead got: %v", shardErr, o.Err)
			}
			// Error returned should not wrap nor contain the sentinel error.
			if errors.Is(o.Err, errCancelAllShards) || strings.Contains(o.Err.Error(), errCancelAllShards.Error()) {
				t.Errorf("error in DownloadOutput should not contain error %q; got: %v", errCancelAllShards, o.Err)
			}
		}
		if o.Range != downloadRange {
			t.Errorf("mismatching download range, got: %v, want: %v", o.Range, downloadRange)
		}
		if o.Object != object {
			t.Errorf("mismatching object names, got: %v, want: %v", o.Object, object)
		}
	}

	wg.Add(1)
	d.downloadsInProgress.Add(1)

	go func() {
		d.gatherShards(in, firstOut, outChan, shards, 0)
		wg.Done()
	}()

	// Send a successful shard, an errored shard, and then a cancelled shard.
	outs[1].Err = shardErr
	outs[2].Err = context.Canceled
	for _, o := range outs {
		outChan <- o
	}

	// Block until the context is cancelled. A Done channel is only ever
	// closed, never sent on, so the receive itself is the whole check; the
	// cause is verified next.
	<-in.ctx.Done()

	if ctxErr := context.Cause(in.ctx); !errors.Is(ctxErr, errCancelAllShards) {
		t.Errorf("context.Cause: error should wrap %q; instead got: %v", errCancelAllShards, ctxErr)
	}

	wg.Wait()

	// Check that the overall error returned also wraps only the proper error.
	_, err = d.WaitAndClose()
	if err == nil {
		// Stop here: the checks below would dereference a nil error.
		t.Fatalf("WaitAndClose: error should wrap error %q; instead got nil", shardErr)
	}
	if !errors.Is(err, shardErr) {
		t.Errorf("error in DownloadOutput should wrap error %q; instead got: %v", shardErr, err)
	}
	if errors.Is(err, errCancelAllShards) || strings.Contains(err.Error(), errCancelAllShards.Error()) {
		t.Errorf("error in DownloadOutput should not contain error %q; got: %v", errCancelAllShards, err)
	}
}

func TestCalculateCRC32C(t *testing.T) {
	t.Parallel()
	for _, test := range []struct {
		desc   string
		pieces []string
	}{
		{
			desc:   "equal sized pieces",
			pieces: []string{"he", "ll", "o ", "wo", "rl", "d!"},
		},
		{
			desc:   "uneven pieces",
			pieces: []string{"hello", " ", "world!"},
		},
		{
			desc: "large pieces",
			pieces: []string{string(bytes.Repeat([]byte("a"), 1024*1024*32)),
				string(bytes.Repeat([]byte("b"), 1024*1024*32)),
				string(bytes.Repeat([]byte("c"), 1024*1024*32)),
			},
		},
	} {
		t.Run(test.desc, func(t *testing.T) {
			initialChecksum := crc32c([]byte(test.pieces[0]))

			remainingChecksums := make([]crc32cPiece, len(test.pieces)-1)
			for i, piece := range test.pieces[1:] {
				remainingChecksums[i] = crc32cPiece{sum: crc32c([]byte(piece)), length: int64(len(piece))}
			}

			got := joinCRC32C(initialChecksum, remainingChecksums)
			want := crc32c([]byte(strings.Join(test.pieces, "")))

			if got != want {
				t.Errorf("crc32c not calculated correctly - want %v, got %v", want, got)
			}
		})
	}
}
