diff --git a/vendor/github.com/hashicorp/go-getter/get.go b/vendor/github.com/hashicorp/go-getter/get.go index c5b6dd453..e6053d934 100644 --- a/vendor/github.com/hashicorp/go-getter/get.go +++ b/vendor/github.com/hashicorp/go-getter/get.go @@ -63,7 +63,7 @@ func init() { "file": new(FileGetter), "git": new(GitGetter), "hg": new(HgGetter), - // "s3": new(S3Getter), + "s3": new(S3Getter), "http": httpGetter, "https": httpGetter, } diff --git a/vendor/github.com/hashicorp/go-getter/get_s3.go b/vendor/github.com/hashicorp/go-getter/get_s3.go new file mode 100644 index 000000000..ebb321741 --- /dev/null +++ b/vendor/github.com/hashicorp/go-getter/get_s3.go @@ -0,0 +1,270 @@ +package getter + +import ( + "fmt" + "io" + "net/url" + "os" + "path/filepath" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" +) + +// S3Getter is a Getter implementation that will download a module from +// a S3 bucket. +type S3Getter struct{} + +func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { + // Parse URL + region, bucket, path, _, creds, err := g.parseUrl(u) + if err != nil { + return 0, err + } + + // Create client config + config := g.getAWSConfig(region, u, creds) + sess := session.New(config) + client := s3.New(sess) + + // List the object(s) at the given prefix + req := &s3.ListObjectsInput{ + Bucket: aws.String(bucket), + Prefix: aws.String(path), + } + resp, err := client.ListObjects(req) + if err != nil { + return 0, err + } + + for _, o := range resp.Contents { + // Use file mode on exact match. + if *o.Key == path { + return ClientModeFile, nil + } + + // Use dir mode if child keys are found. + if strings.HasPrefix(*o.Key, path+"/") { + return ClientModeDir, nil + } + } + + // There was no match, so just return file mode. The download is going + // to fail but we will let S3 return the proper error later. + return ClientModeFile, nil +} + +func (g *S3Getter) Get(dst string, u *url.URL) error { + // Parse URL + region, bucket, path, _, creds, err := g.parseUrl(u) + if err != nil { + return err + } + + // Remove destination if it already exists + _, err = os.Stat(dst) + if err != nil && !os.IsNotExist(err) { + return err + } + + if err == nil { + // Remove the destination + if err := os.RemoveAll(dst); err != nil { + return err + } + } + + // Create all the parent directories + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return err + } + + config := g.getAWSConfig(region, u, creds) + sess := session.New(config) + client := s3.New(sess) + + // List files in path, keep listing until no more objects are found + lastMarker := "" + hasMore := true + for hasMore { + req := &s3.ListObjectsInput{ + Bucket: aws.String(bucket), + Prefix: aws.String(path), + } + if lastMarker != "" { + req.Marker = aws.String(lastMarker) + } + + resp, err := client.ListObjects(req) + if err != nil { + return err + } + + hasMore = aws.BoolValue(resp.IsTruncated) + + // Get each object storing each file relative to the destination path + for _, object := range resp.Contents { + lastMarker = aws.StringValue(object.Key) + objPath := aws.StringValue(object.Key) + + // If the key ends with a backslash assume it is a directory and ignore + if strings.HasSuffix(objPath, "/") { + continue + } + + // Get the object destination path + objDst, err := filepath.Rel(path, objPath) + if err != nil { + return err + } + objDst = filepath.Join(dst, objDst) + + if err := g.getObject(client, objDst, bucket, objPath, ""); err != nil { + return err + } + } + } + + return nil +} + +func (g *S3Getter) GetFile(dst string, u *url.URL) error { + region, bucket, path, version, creds, err := g.parseUrl(u) + if err != nil { + return err + } + + config := g.getAWSConfig(region, u, creds) + sess := session.New(config) + client := s3.New(sess) + return g.getObject(client, dst, bucket, path, version) +} + +func (g *S3Getter) getObject(client *s3.S3, dst, bucket, key, version string) error { + req := &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + } + if version != "" { + req.VersionId = aws.String(version) + } + + resp, err := client.GetObject(req) + if err != nil { + return err + } + + // Create all the parent directories + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return err + } + + f, err := os.Create(dst) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(f, resp.Body) + return err +} + +func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) *aws.Config { + conf := &aws.Config{} + if creds == nil { + // Grab the metadata URL + metadataURL := os.Getenv("AWS_METADATA_URL") + if metadataURL == "" { + metadataURL = "http://169.254.169.254:80/latest" + } + + creds = credentials.NewChainCredentials( + []credentials.Provider{ + &credentials.EnvProvider{}, + &credentials.SharedCredentialsProvider{Filename: "", Profile: ""}, + &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(session.New(&aws.Config{ + Endpoint: aws.String(metadataURL), + })), + }, + }) + } + + if creds != nil { + conf.Endpoint = &url.Host + conf.S3ForcePathStyle = aws.Bool(true) + if url.Scheme == "http" { + conf.DisableSSL = aws.Bool(true) + } + } + + conf.Credentials = creds + if region != "" { + conf.Region = aws.String(region) + } + + return conf +} + +func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) { + // This just check whether we are dealing with S3 or + // any other S3 compliant service. S3 has a predictable + // url as others do not + if strings.Contains(u.Host, "amazonaws.com") { + // Expected host style: s3.amazonaws.com. They always have 3 parts, + // although the first may differ if we're accessing a specific region. + hostParts := strings.Split(u.Host, ".") + if len(hostParts) != 3 { + err = fmt.Errorf("URL is not a valid S3 URL") + return + } + + // Parse the region out of the first part of the host + region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3") + if region == "" { + region = "us-east-1" + } + + pathParts := strings.SplitN(u.Path, "/", 3) + if len(pathParts) != 3 { + err = fmt.Errorf("URL is not a valid S3 URL") + return + } + + bucket = pathParts[1] + path = pathParts[2] + version = u.Query().Get("version") + + } else { + pathParts := strings.SplitN(u.Path, "/", 3) + if len(pathParts) != 3 { + err = fmt.Errorf("URL is not a valid S3 complaint URL") + return + } + bucket = pathParts[1] + path = pathParts[2] + version = u.Query().Get("version") + region = u.Query().Get("region") + if region == "" { + region = "us-east-1" + } + } + + _, hasAwsId := u.Query()["aws_access_key_id"] + _, hasAwsSecret := u.Query()["aws_access_key_secret"] + _, hasAwsToken := u.Query()["aws_access_token"] + if hasAwsId || hasAwsSecret || hasAwsToken { + creds = credentials.NewStaticCredentials( + u.Query().Get("aws_access_key_id"), + u.Query().Get("aws_access_key_secret"), + u.Query().Get("aws_access_token"), + ) + } + + return +} diff --git a/vendor/github.com/hashicorp/go-getter/get_s3_test.go b/vendor/github.com/hashicorp/go-getter/get_s3_test.go new file mode 100644 index 000000000..2d9da14cb --- /dev/null +++ b/vendor/github.com/hashicorp/go-getter/get_s3_test.go @@ -0,0 +1,250 @@ +package getter + +import ( + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/aws/aws-sdk-go/aws/awserr" +) + +func init() { + // These are well known restricted IAM keys to a HashiCorp-managed bucket + // in a private AWS account that only has access to the open source test + // resources. + // + // We do the string concat below to avoid AWS autodetection of a key. This + // key is locked down an IAM policy that is read-only so we're purposely + // exposing it. + os.Setenv("AWS_ACCESS_KEY", "AKIAITTDR"+"WY2STXOZE2A") + os.Setenv("AWS_SECRET_KEY", "oMwSyqdass2kPF"+"/7ORZA9dlb/iegz+89B0Cy01Ea") +} + +func TestS3Getter_impl(t *testing.T) { + var _ Getter = new(S3Getter) +} + +func TestS3Getter(t *testing.T) { + g := new(S3Getter) + dst := tempDir(t) + + // With a dir that doesn't exist + err := g.Get( + dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder")) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestS3Getter_subdir(t *testing.T) { + g := new(S3Getter) + dst := tempDir(t) + + // With a dir that doesn't exist + err := g.Get( + dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/subfolder")) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + subPath := filepath.Join(dst, "sub.tf") + if _, err := os.Stat(subPath); err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestS3Getter_GetFile(t *testing.T) { + g := new(S3Getter) + dst := tempFile(t) + + // Download + err := g.GetFile( + dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf")) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + if _, err := os.Stat(dst); err != nil { + t.Fatalf("err: %s", err) + } + assertContents(t, dst, "# Main\n") +} + +func TestS3Getter_GetFile_badParams(t *testing.T) { + g := new(S3Getter) + dst := tempFile(t) + + // Download + err := g.GetFile( + dst, + testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf?aws_access_key_id=foo&aws_access_key_secret=bar&aws_access_token=baz")) + if err == nil { + t.Fatalf("expected error, got none") + } + + if reqerr, ok := err.(awserr.RequestFailure); !ok || reqerr.StatusCode() != 403 { + t.Fatalf("expected InvalidAccessKeyId error") + } +} + +func TestS3Getter_GetFile_notfound(t *testing.T) { + g := new(S3Getter) + dst := tempFile(t) + + // Download + err := g.GetFile( + dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/404.tf")) + if err == nil { + t.Fatalf("expected error, got none") + } +} + +func TestS3Getter_ClientMode_dir(t *testing.T) { + g := new(S3Getter) + + // Check client mode on a key prefix with only a single key. + mode, err := g.ClientMode( + testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder")) + if err != nil { + t.Fatalf("err: %s", err) + } + if mode != ClientModeDir { + t.Fatal("expect ClientModeDir") + } +} + +func TestS3Getter_ClientMode_file(t *testing.T) { + g := new(S3Getter) + + // Check client mode on a key prefix which contains sub-keys. + mode, err := g.ClientMode( + testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf")) + if err != nil { + t.Fatalf("err: %s", err) + } + if mode != ClientModeFile { + t.Fatal("expect ClientModeFile") + } +} + +func TestS3Getter_ClientMode_notfound(t *testing.T) { + g := new(S3Getter) + + // Check the client mode when a non-existent key is looked up. This does not + // return an error, but rather should just return the file mode so that S3 + // can return an appropriate error later on. This also checks that the + // prefix is handled properly (e.g., "/fold" and "/folder" don't put the + // client mode into "dir". + mode, err := g.ClientMode( + testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/fold")) + if err != nil { + t.Fatalf("err: %s", err) + } + if mode != ClientModeFile { + t.Fatal("expect ClientModeFile") + } +} + +func TestS3Getter_ClientMode_collision(t *testing.T) { + g := new(S3Getter) + + // Check that the client mode is "file" if there is both an object and a + // folder with a common prefix (i.e., a "collision" in the namespace). + mode, err := g.ClientMode( + testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/collision/foo")) + if err != nil { + t.Fatalf("err: %s", err) + } + if mode != ClientModeFile { + t.Fatal("expect ClientModeFile") + } +} + +func TestS3Getter_Url(t *testing.T) { + var s3tests = []struct { + name string + url string + region string + bucket string + path string + version string + }{ + { + name: "AWSv1234", + url: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo/bar.baz?version=1234", + region: "eu-west-1", + bucket: "bucket", + path: "foo/bar.baz", + version: "1234", + }, + { + name: "localhost-1", + url: "s3::http://127.0.0.1:9000/test-bucket/hello.txt?aws_access_key_id=TESTID&aws_access_key_secret=TestSecret®ion=us-east-2&version=1", + region: "us-east-2", + bucket: "test-bucket", + path: "hello.txt", + version: "1", + }, + { + name: "localhost-2", + url: "s3::http://127.0.0.1:9000/test-bucket/hello.txt?aws_access_key_id=TESTID&aws_access_key_secret=TestSecret&version=1", + region: "us-east-1", + bucket: "test-bucket", + path: "hello.txt", + version: "1", + }, + { + name: "localhost-3", + url: "s3::http://127.0.0.1:9000/test-bucket/hello.txt?aws_access_key_id=TESTID&aws_access_key_secret=TestSecret", + region: "us-east-1", + bucket: "test-bucket", + path: "hello.txt", + version: "", + }, + } + + for i, pt := range s3tests { + t.Run(pt.name, func(t *testing.T) { + g := new(S3Getter) + forced, src := getForcedGetter(pt.url) + u, err := url.Parse(src) + + if err != nil { + t.Errorf("test %d: unexpected error: %s", i, err) + } + if forced != "s3" { + t.Fatalf("expected forced protocol to be s3") + } + + region, bucket, path, version, creds, err := g.parseUrl(u) + + if err != nil { + t.Fatalf("err: %s", err) + } + if region != pt.region { + t.Fatalf("expected %s, got %s", pt.region, region) + } + if bucket != pt.bucket { + t.Fatalf("expected %s, got %s", pt.bucket, bucket) + } + if path != pt.path { + t.Fatalf("expected %s, got %s", pt.path, path) + } + if version != pt.version { + t.Fatalf("expected %s, got %s", pt.version, version) + } + if &creds == nil { + t.Fatalf("expected to not be nil") + } + }) + } +}