Add all dependency of go-getter

This commit is contained in:
Jingfang Liu
2018-08-15 11:34:38 -07:00
parent c9a8bc1121
commit ec95e5f97e
2894 changed files with 1945864 additions and 0 deletions

22068
vendor/github.com/aws/aws-sdk-go/service/s3/api.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,58 @@
package s3
import (
"bytes"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
)
func BenchmarkPresign_GetObject(b *testing.B) {
sess := unit.Session
svc := New(sess)
for i := 0; i < b.N; i++ {
req, _ := svc.GetObjectRequest(&GetObjectInput{
Bucket: aws.String("mock-bucket"),
Key: aws.String("mock-key"),
})
u, h, err := req.PresignRequest(15 * time.Minute)
if err != nil {
b.Fatalf("expect no error, got %v", err)
}
if len(u) == 0 {
b.Fatalf("expect url, got none")
}
if len(h) != 0 {
b.Fatalf("no signed headers, got %v", h)
}
}
}
func BenchmarkPresign_PutObject(b *testing.B) {
sess := unit.Session
svc := New(sess)
body := make([]byte, 1024*1024*20)
for i := 0; i < b.N; i++ {
req, _ := svc.PutObjectRequest(&PutObjectInput{
Bucket: aws.String("mock-bucket"),
Key: aws.String("mock-key"),
Body: bytes.NewReader(body),
})
u, h, err := req.PresignRequest(15 * time.Minute)
if err != nil {
b.Fatalf("expect no error, got %v", err)
}
if len(u) == 0 {
b.Fatalf("expect url, got none")
}
if len(h) == 0 {
b.Fatalf("expect signed header, got none")
}
}
}

View File

@@ -0,0 +1,249 @@
package s3
import (
"bytes"
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"hash"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
const (
contentMD5Header = "Content-Md5"
contentSha256Header = "X-Amz-Content-Sha256"
amzTeHeader = "X-Amz-Te"
amzTxEncodingHeader = "X-Amz-Transfer-Encoding"
appendMD5TxEncoding = "append-md5"
)
// contentMD5 computes and sets the HTTP Content-MD5 header for requests that
// require it.
func contentMD5(r *request.Request) {
h := md5.New()
if !aws.IsReaderSeekable(r.Body) {
if r.Config.Logger != nil {
r.Config.Logger.Log(fmt.Sprintf(
"Unable to compute Content-MD5 for unseekable body, S3.%s",
r.Operation.Name))
}
return
}
if _, err := copySeekableBody(h, r.Body); err != nil {
r.Error = awserr.New("ContentMD5", "failed to compute body MD5", err)
return
}
// encode the md5 checksum in base64 and set the request header.
v := base64.StdEncoding.EncodeToString(h.Sum(nil))
r.HTTPRequest.Header.Set(contentMD5Header, v)
}
// computeBodyHashes will add Content MD5 and Content Sha256 hashes to the
// request. If the body is not seekable or S3DisableContentMD5Validation set
// this handler will be ignored.
func computeBodyHashes(r *request.Request) {
if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
return
}
if r.IsPresigned() {
return
}
if r.Error != nil || !aws.IsReaderSeekable(r.Body) {
return
}
var md5Hash, sha256Hash hash.Hash
hashers := make([]io.Writer, 0, 2)
// Determine upfront which hashes can be set without overriding user
// provide header data.
if v := r.HTTPRequest.Header.Get(contentMD5Header); len(v) == 0 {
md5Hash = md5.New()
hashers = append(hashers, md5Hash)
}
if v := r.HTTPRequest.Header.Get(contentSha256Header); len(v) == 0 {
sha256Hash = sha256.New()
hashers = append(hashers, sha256Hash)
}
// Create the destination writer based on the hashes that are not already
// provided by the user.
var dst io.Writer
switch len(hashers) {
case 0:
return
case 1:
dst = hashers[0]
default:
dst = io.MultiWriter(hashers...)
}
if _, err := copySeekableBody(dst, r.Body); err != nil {
r.Error = awserr.New("BodyHashError", "failed to compute body hashes", err)
return
}
// For the hashes created, set the associated headers that the user did not
// already provide.
if md5Hash != nil {
sum := make([]byte, md5.Size)
encoded := make([]byte, md5Base64EncLen)
base64.StdEncoding.Encode(encoded, md5Hash.Sum(sum[0:0]))
r.HTTPRequest.Header[contentMD5Header] = []string{string(encoded)}
}
if sha256Hash != nil {
encoded := make([]byte, sha256HexEncLen)
sum := make([]byte, sha256.Size)
hex.Encode(encoded, sha256Hash.Sum(sum[0:0]))
r.HTTPRequest.Header[contentSha256Header] = []string{string(encoded)}
}
}
const (
md5Base64EncLen = (md5.Size + 2) / 3 * 4 // base64.StdEncoding.EncodedLen
sha256HexEncLen = sha256.Size * 2 // hex.EncodedLen
)
func copySeekableBody(dst io.Writer, src io.ReadSeeker) (int64, error) {
curPos, err := src.Seek(0, sdkio.SeekCurrent)
if err != nil {
return 0, err
}
// hash the body. seek back to the first position after reading to reset
// the body for transmission. copy errors may be assumed to be from the
// body.
n, err := io.Copy(dst, src)
if err != nil {
return n, err
}
_, err = src.Seek(curPos, sdkio.SeekStart)
if err != nil {
return n, err
}
return n, nil
}
// Adds the x-amz-te: append_md5 header to the request. This requests the service
// responds with a trailing MD5 checksum.
//
// Will not ask for append MD5 if disabled, the request is presigned or,
// or the API operation does not support content MD5 validation.
func askForTxEncodingAppendMD5(r *request.Request) {
if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
return
}
if r.IsPresigned() {
return
}
r.HTTPRequest.Header.Set(amzTeHeader, appendMD5TxEncoding)
}
func useMD5ValidationReader(r *request.Request) {
if r.Error != nil {
return
}
if v := r.HTTPResponse.Header.Get(amzTxEncodingHeader); v != appendMD5TxEncoding {
return
}
var bodyReader *io.ReadCloser
var contentLen int64
switch tv := r.Data.(type) {
case *GetObjectOutput:
bodyReader = &tv.Body
contentLen = aws.Int64Value(tv.ContentLength)
// Update ContentLength hiden the trailing MD5 checksum.
tv.ContentLength = aws.Int64(contentLen - md5.Size)
tv.ContentRange = aws.String(r.HTTPResponse.Header.Get("X-Amz-Content-Range"))
default:
r.Error = awserr.New("ChecksumValidationError",
fmt.Sprintf("%s: %s header received on unsupported API, %s",
amzTxEncodingHeader, appendMD5TxEncoding, r.Operation.Name,
), nil)
return
}
if contentLen < md5.Size {
r.Error = awserr.New("ChecksumValidationError",
fmt.Sprintf("invalid Content-Length %d for %s %s",
contentLen, appendMD5TxEncoding, amzTxEncodingHeader,
), nil)
return
}
// Wrap and swap the response body reader with the validation reader.
*bodyReader = newMD5ValidationReader(*bodyReader, contentLen-md5.Size)
}
type md5ValidationReader struct {
rawReader io.ReadCloser
payload io.Reader
hash hash.Hash
payloadLen int64
read int64
}
func newMD5ValidationReader(reader io.ReadCloser, payloadLen int64) *md5ValidationReader {
h := md5.New()
return &md5ValidationReader{
rawReader: reader,
payload: io.TeeReader(&io.LimitedReader{R: reader, N: payloadLen}, h),
hash: h,
payloadLen: payloadLen,
}
}
func (v *md5ValidationReader) Read(p []byte) (n int, err error) {
n, err = v.payload.Read(p)
if err != nil && err != io.EOF {
return n, err
}
v.read += int64(n)
if err == io.EOF {
if v.read != v.payloadLen {
return n, io.ErrUnexpectedEOF
}
expectSum := make([]byte, md5.Size)
actualSum := make([]byte, md5.Size)
if _, sumReadErr := io.ReadFull(v.rawReader, expectSum); sumReadErr != nil {
return n, sumReadErr
}
actualSum = v.hash.Sum(actualSum[0:0])
if !bytes.Equal(expectSum, actualSum) {
return n, awserr.New("InvalidChecksum",
fmt.Sprintf("expected MD5 checksum %s, got %s",
hex.EncodeToString(expectSum),
hex.EncodeToString(actualSum),
),
nil)
}
}
return n, err
}
func (v *md5ValidationReader) Close() error {
return v.rawReader.Close()
}

View File

@@ -0,0 +1,523 @@
package s3
import (
"bytes"
"crypto/md5"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
type errorReader struct{}
func (errorReader) Read([]byte) (int, error) {
return 0, fmt.Errorf("errorReader error")
}
func (errorReader) Seek(int64, int) (int64, error) {
return 0, nil
}
func TestComputeBodyHases(t *testing.T) {
bodyContent := []byte("bodyContent goes here")
cases := []struct {
Req *request.Request
ExpectMD5 string
ExpectSHA256 string
Error string
DisableContentMD5 bool
Presigned bool
}{
{
Req: &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
Body: bytes.NewReader(bodyContent),
},
ExpectMD5: "CqD6NNPvoNOBT/5pkjtzOw==",
ExpectSHA256: "3ff09c8b42a58a905e27835919ede45b61722e7cd400f30101bd9ed1a69a1825",
},
{
Req: &request.Request{
HTTPRequest: &http.Request{
Header: func() http.Header {
h := http.Header{}
h.Set(contentMD5Header, "MD5AlreadySet")
return h
}(),
},
Body: bytes.NewReader(bodyContent),
},
ExpectMD5: "MD5AlreadySet",
ExpectSHA256: "3ff09c8b42a58a905e27835919ede45b61722e7cd400f30101bd9ed1a69a1825",
},
{
Req: &request.Request{
HTTPRequest: &http.Request{
Header: func() http.Header {
h := http.Header{}
h.Set(contentSha256Header, "SHA256AlreadySet")
return h
}(),
},
Body: bytes.NewReader(bodyContent),
},
ExpectMD5: "CqD6NNPvoNOBT/5pkjtzOw==",
ExpectSHA256: "SHA256AlreadySet",
},
{
Req: &request.Request{
HTTPRequest: &http.Request{
Header: func() http.Header {
h := http.Header{}
h.Set(contentMD5Header, "MD5AlreadySet")
h.Set(contentSha256Header, "SHA256AlreadySet")
return h
}(),
},
Body: bytes.NewReader(bodyContent),
},
ExpectMD5: "MD5AlreadySet",
ExpectSHA256: "SHA256AlreadySet",
},
{
Req: &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
// Non-seekable reader
Body: aws.ReadSeekCloser(bytes.NewBuffer(bodyContent)),
},
ExpectMD5: "",
ExpectSHA256: "",
},
{
Req: &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
// Empty seekable body
Body: aws.ReadSeekCloser(bytes.NewReader(nil)),
},
ExpectMD5: "1B2M2Y8AsgTpgAmY7PhCfg==",
ExpectSHA256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
{
Req: &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
// failure while reading reader
Body: errorReader{},
},
ExpectMD5: "",
ExpectSHA256: "",
Error: "errorReader error",
},
{
// Disabled ContentMD5 validation
Req: &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
Body: bytes.NewReader(bodyContent),
},
ExpectMD5: "",
ExpectSHA256: "",
DisableContentMD5: true,
},
{
// Disabled ContentMD5 validation
Req: &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
Body: bytes.NewReader(bodyContent),
},
ExpectMD5: "",
ExpectSHA256: "",
Presigned: true,
},
}
for i, c := range cases {
c.Req.Config.S3DisableContentMD5Validation = aws.Bool(c.DisableContentMD5)
if c.Presigned {
c.Req.ExpireTime = 10 * time.Minute
}
computeBodyHashes(c.Req)
if e, a := c.ExpectMD5, c.Req.HTTPRequest.Header.Get(contentMD5Header); e != a {
t.Errorf("%d, expect %v md5, got %v", i, e, a)
}
if e, a := c.ExpectSHA256, c.Req.HTTPRequest.Header.Get(contentSha256Header); e != a {
t.Errorf("%d, expect %v sha256, got %v", i, e, a)
}
if len(c.Error) != 0 {
if c.Req.Error == nil {
t.Fatalf("%d, expect error, got none", i)
}
if e, a := c.Error, c.Req.Error.Error(); !strings.Contains(a, e) {
t.Errorf("%d, expect %v error to be in %v", i, e, a)
}
} else if c.Req.Error != nil {
t.Errorf("%d, expect no error, got %v", i, c.Req.Error)
}
}
}
func BenchmarkComputeBodyHashes(b *testing.B) {
body := bytes.NewReader(make([]byte, 2*1024))
req := &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
Body: body,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
computeBodyHashes(req)
if req.Error != nil {
b.Fatalf("expect no error, got %v", req.Error)
}
req.HTTPRequest.Header = http.Header{}
body.Seek(0, sdkio.SeekStart)
}
}
func TestAskForTxEncodingAppendMD5(t *testing.T) {
cases := []struct {
DisableContentMD5 bool
Presigned bool
}{
{DisableContentMD5: true},
{DisableContentMD5: false},
{Presigned: true},
}
for i, c := range cases {
req := &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
Config: aws.Config{
S3DisableContentMD5Validation: aws.Bool(c.DisableContentMD5),
},
}
if c.Presigned {
req.ExpireTime = 10 * time.Minute
}
askForTxEncodingAppendMD5(req)
v := req.HTTPRequest.Header.Get(amzTeHeader)
expectHeader := !(c.DisableContentMD5 || c.Presigned)
if e, a := expectHeader, len(v) != 0; e != a {
t.Errorf("%d, expect %t disable content MD5, got %t, %s", i, e, a, v)
}
}
}
func TestUseMD5ValidationReader(t *testing.T) {
body := []byte("create a really cool md5 checksum of me")
bodySum := md5.Sum(body)
bodyWithSum := append(body, bodySum[:]...)
emptyBodySum := md5.Sum([]byte{})
cases := []struct {
Req *request.Request
Error string
Validate func(outupt interface{}) error
}{
{
// Positive: Use Validation reader
Req: &request.Request{
HTTPResponse: &http.Response{
Header: func() http.Header {
h := http.Header{}
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
return h
}(),
},
Data: &GetObjectOutput{
Body: ioutil.NopCloser(bytes.NewReader(bodyWithSum)),
ContentLength: aws.Int64(int64(len(bodyWithSum))),
},
},
Validate: func(output interface{}) error {
getObjOut := output.(*GetObjectOutput)
reader, ok := getObjOut.Body.(*md5ValidationReader)
if !ok {
return fmt.Errorf("expect %T updated body reader, got %T",
(*md5ValidationReader)(nil), getObjOut.Body)
}
if reader.rawReader == nil {
return fmt.Errorf("expect rawReader not to be nil")
}
if reader.payload == nil {
return fmt.Errorf("expect payload not to be nil")
}
if e, a := int64(len(bodyWithSum)-md5.Size), reader.payloadLen; e != a {
return fmt.Errorf("expect %v payload len, got %v", e, a)
}
if reader.hash == nil {
return fmt.Errorf("expect hash not to be nil")
}
return nil
},
},
{
// Positive: Use Validation reader, empty object
Req: &request.Request{
HTTPResponse: &http.Response{
Header: func() http.Header {
h := http.Header{}
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
return h
}(),
},
Data: &GetObjectOutput{
Body: ioutil.NopCloser(bytes.NewReader(emptyBodySum[:])),
ContentLength: aws.Int64(int64(len(emptyBodySum[:]))),
},
},
Validate: func(output interface{}) error {
getObjOut := output.(*GetObjectOutput)
reader, ok := getObjOut.Body.(*md5ValidationReader)
if !ok {
return fmt.Errorf("expect %T updated body reader, got %T",
(*md5ValidationReader)(nil), getObjOut.Body)
}
if reader.rawReader == nil {
return fmt.Errorf("expect rawReader not to be nil")
}
if reader.payload == nil {
return fmt.Errorf("expect payload not to be nil")
}
if e, a := int64(len(emptyBodySum)-md5.Size), reader.payloadLen; e != a {
return fmt.Errorf("expect %v payload len, got %v", e, a)
}
if reader.hash == nil {
return fmt.Errorf("expect hash not to be nil")
}
return nil
},
},
{
// Negative: amzTxEncoding header not set
Req: &request.Request{
HTTPResponse: &http.Response{
Header: http.Header{},
},
Data: &GetObjectOutput{
Body: ioutil.NopCloser(bytes.NewReader(body)),
ContentLength: aws.Int64(int64(len(body))),
},
},
Validate: func(output interface{}) error {
getObjOut := output.(*GetObjectOutput)
reader, ok := getObjOut.Body.(*md5ValidationReader)
if ok {
return fmt.Errorf("expect body reader not to be %T",
reader)
}
return nil
},
},
{
// Negative: Not GetObjectOutput type.
Req: &request.Request{
Operation: &request.Operation{
Name: "PutObject",
},
HTTPResponse: &http.Response{
Header: func() http.Header {
h := http.Header{}
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
return h
}(),
},
Data: &PutObjectOutput{},
},
Error: "header received on unsupported API",
Validate: func(output interface{}) error {
_, ok := output.(*PutObjectOutput)
if !ok {
return fmt.Errorf("expect %T output not to change, got %T",
(*PutObjectOutput)(nil), output)
}
return nil
},
},
{
// Negative: invalid content length.
Req: &request.Request{
HTTPResponse: &http.Response{
Header: func() http.Header {
h := http.Header{}
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
return h
}(),
},
Data: &GetObjectOutput{
Body: ioutil.NopCloser(bytes.NewReader(bodyWithSum)),
ContentLength: aws.Int64(-1),
},
},
Error: "invalid Content-Length -1",
Validate: func(output interface{}) error {
getObjOut := output.(*GetObjectOutput)
reader, ok := getObjOut.Body.(*md5ValidationReader)
if ok {
return fmt.Errorf("expect body reader not to be %T",
reader)
}
return nil
},
},
{
// Negative: invalid content length, < md5.Size.
Req: &request.Request{
HTTPResponse: &http.Response{
Header: func() http.Header {
h := http.Header{}
h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
return h
}(),
},
Data: &GetObjectOutput{
Body: ioutil.NopCloser(bytes.NewReader(make([]byte, 5))),
ContentLength: aws.Int64(5),
},
},
Error: "invalid Content-Length 5",
Validate: func(output interface{}) error {
getObjOut := output.(*GetObjectOutput)
reader, ok := getObjOut.Body.(*md5ValidationReader)
if ok {
return fmt.Errorf("expect body reader not to be %T",
reader)
}
return nil
},
},
}
for i, c := range cases {
useMD5ValidationReader(c.Req)
if len(c.Error) != 0 {
if c.Req.Error == nil {
t.Fatalf("%d, expect error, got none", i)
}
if e, a := c.Error, c.Req.Error.Error(); !strings.Contains(a, e) {
t.Errorf("%d, expect %v error to be in %v", i, e, a)
}
} else if c.Req.Error != nil {
t.Errorf("%d, expect no error, got %v", i, c.Req.Error)
}
if c.Validate != nil {
if err := c.Validate(c.Req.Data); err != nil {
t.Errorf("%d, expect Data to validate, got %v", i, err)
}
}
}
}
func TestReaderMD5Validation(t *testing.T) {
body := []byte("create a really cool md5 checksum of me")
bodySum := md5.Sum(body)
bodyWithSum := append(body, bodySum[:]...)
emptyBodySum := md5.Sum([]byte{})
badBodySum := append(body, emptyBodySum[:]...)
cases := []struct {
Content []byte
ContentReader io.ReadCloser
PayloadLen int64
Error string
}{
{
Content: bodyWithSum,
PayloadLen: int64(len(body)),
},
{
Content: emptyBodySum[:],
PayloadLen: 0,
},
{
Content: badBodySum,
PayloadLen: int64(len(body)),
Error: "expected MD5 checksum",
},
{
Content: emptyBodySum[:len(emptyBodySum)-2],
PayloadLen: 0,
Error: "unexpected EOF",
},
{
Content: body,
PayloadLen: int64(len(body) * 2),
Error: "unexpected EOF",
},
{
ContentReader: ioutil.NopCloser(errorReader{}),
PayloadLen: int64(len(body)),
Error: "errorReader error",
},
}
for i, c := range cases {
reader := c.ContentReader
if reader == nil {
reader = ioutil.NopCloser(bytes.NewReader(c.Content))
}
v := newMD5ValidationReader(reader, c.PayloadLen)
var actual bytes.Buffer
n, err := io.Copy(&actual, v)
if len(c.Error) != 0 {
if err == nil {
t.Errorf("%d, expect error, got none", i)
}
if e, a := c.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%d, expect %v error to be in %v", i, e, a)
}
continue
} else if err != nil {
t.Errorf("%d, expect no error, got %v", i, err)
continue
}
if e, a := c.PayloadLen, n; e != a {
t.Errorf("%d, expect %v len, got %v", i, e, a)
}
if e, a := c.Content[:c.PayloadLen], actual.Bytes(); !bytes.Equal(e, a) {
t.Errorf("%d, expect:\n%v\nactual:\n%v", i, e, a)
}
}
}

View File

@@ -0,0 +1,106 @@
package s3
import (
"io/ioutil"
"regexp"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
var reBucketLocation = regexp.MustCompile(`>([^<>]+)<\/Location`)
// NormalizeBucketLocation is a utility function which will update the
// passed in value to always be a region ID. Generally this would be used
// with GetBucketLocation API operation.
//
// Replaces empty string with "us-east-1", and "EU" with "eu-west-1".
//
// See http://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGETlocation.html
// for more information on the values that can be returned.
func NormalizeBucketLocation(loc string) string {
switch loc {
case "":
loc = "us-east-1"
case "EU":
loc = "eu-west-1"
}
return loc
}
// NormalizeBucketLocationHandler is a request handler which will update the
// GetBucketLocation's result LocationConstraint value to always be a region ID.
//
// Replaces empty string with "us-east-1", and "EU" with "eu-west-1".
//
// See http://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGETlocation.html
// for more information on the values that can be returned.
//
// req, result := svc.GetBucketLocationRequest(&s3.GetBucketLocationInput{
// Bucket: aws.String(bucket),
// })
// req.Handlers.Unmarshal.PushBackNamed(NormalizeBucketLocationHandler)
// err := req.Send()
var NormalizeBucketLocationHandler = request.NamedHandler{
Name: "awssdk.s3.NormalizeBucketLocation",
Fn: func(req *request.Request) {
if req.Error != nil {
return
}
out := req.Data.(*GetBucketLocationOutput)
loc := NormalizeBucketLocation(aws.StringValue(out.LocationConstraint))
out.LocationConstraint = aws.String(loc)
},
}
// WithNormalizeBucketLocation is a request option which will update the
// GetBucketLocation's result LocationConstraint value to always be a region ID.
//
// Replaces empty string with "us-east-1", and "EU" with "eu-west-1".
//
// See http://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGETlocation.html
// for more information on the values that can be returned.
//
// result, err := svc.GetBucketLocationWithContext(ctx,
// &s3.GetBucketLocationInput{
// Bucket: aws.String(bucket),
// },
// s3.WithNormalizeBucketLocation,
// )
func WithNormalizeBucketLocation(r *request.Request) {
r.Handlers.Unmarshal.PushBackNamed(NormalizeBucketLocationHandler)
}
func buildGetBucketLocation(r *request.Request) {
if r.DataFilled() {
out := r.Data.(*GetBucketLocationOutput)
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed reading response body", err)
return
}
match := reBucketLocation.FindSubmatch(b)
if len(match) > 1 {
loc := string(match[1])
out.LocationConstraint = aws.String(loc)
}
}
}
func populateLocationConstraint(r *request.Request) {
if r.ParamsFilled() && aws.StringValue(r.Config.Region) != "us-east-1" {
in := r.Params.(*CreateBucketInput)
if in.CreateBucketConfiguration == nil {
r.Params = awsutil.CopyOf(r.Params)
in = r.Params.(*CreateBucketInput)
in.CreateBucketConfiguration = &CreateBucketConfiguration{
LocationConstraint: r.Config.Region,
}
}
}
}

View File

@@ -0,0 +1,141 @@
package s3_test
import (
"bytes"
"io/ioutil"
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
var s3LocationTests = []struct {
body string
loc string
}{
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/"/>`, ``},
{`<?xml version="1.0" encoding="UTF-8"?><LocationConstraint xmlns="http://s3.amazonaws.com/doc/2006-03-01/">EU</LocationConstraint>`, `EU`},
}
func TestGetBucketLocation(t *testing.T) {
for _, test := range s3LocationTests {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
reader := ioutil.NopCloser(bytes.NewReader([]byte(test.body)))
r.HTTPResponse = &http.Response{StatusCode: 200, Body: reader}
})
resp, err := s.GetBucketLocation(&s3.GetBucketLocationInput{Bucket: aws.String("bucket")})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if test.loc == "" {
if v := resp.LocationConstraint; v != nil {
t.Errorf("expect location constraint to be nil, got %s", *v)
}
} else {
if e, a := test.loc, *resp.LocationConstraint; e != a {
t.Errorf("expect %s location constraint, got %v", e, a)
}
}
}
}
func TestNormalizeBucketLocation(t *testing.T) {
cases := []struct {
In, Out string
}{
{"", "us-east-1"},
{"EU", "eu-west-1"},
{"us-east-1", "us-east-1"},
{"something", "something"},
}
for i, c := range cases {
actual := s3.NormalizeBucketLocation(c.In)
if e, a := c.Out, actual; e != a {
t.Errorf("%d, expect %s bucket location, got %s", i, e, a)
}
}
}
func TestWithNormalizeBucketLocation(t *testing.T) {
req := &request.Request{}
req.ApplyOptions(s3.WithNormalizeBucketLocation)
cases := []struct {
In, Out string
}{
{"", "us-east-1"},
{"EU", "eu-west-1"},
{"us-east-1", "us-east-1"},
{"something", "something"},
}
for i, c := range cases {
req.Data = &s3.GetBucketLocationOutput{
LocationConstraint: aws.String(c.In),
}
req.Handlers.Unmarshal.Run(req)
v := req.Data.(*s3.GetBucketLocationOutput).LocationConstraint
if e, a := c.Out, aws.StringValue(v); e != a {
t.Errorf("%d, expect %s bucket location, got %s", i, e, a)
}
}
}
func TestPopulateLocationConstraint(t *testing.T) {
s := s3.New(unit.Session)
in := &s3.CreateBucketInput{
Bucket: aws.String("bucket"),
}
req, _ := s.CreateBucketRequest(in)
if err := req.Build(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
if e, a := "mock-region", *(v[0].(*string)); e != a {
t.Errorf("expect %s location constraint, got %s", e, a)
}
if v := in.CreateBucketConfiguration; v != nil {
// don't modify original params
t.Errorf("expect create bucket Configuration to be nil, got %s", *v)
}
}
func TestNoPopulateLocationConstraintIfProvided(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
CreateBucketConfiguration: &s3.CreateBucketConfiguration{},
})
if err := req.Build(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
if l := len(v); l != 0 {
t.Errorf("expect no values, got %d", l)
}
}
func TestNoPopulateLocationConstraintIfClassic(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{Region: aws.String("us-east-1")})
req, _ := s.CreateBucketRequest(&s3.CreateBucketInput{
Bucket: aws.String("bucket"),
})
if err := req.Build(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
if l := len(v); l != 0 {
t.Errorf("expect no values, got %d", l)
}
}

View File

@@ -0,0 +1,70 @@
package s3
import (
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
)
func init() {
initClient = defaultInitClientFn
initRequest = defaultInitRequestFn
}
func defaultInitClientFn(c *client.Client) {
// Support building custom endpoints based on config
c.Handlers.Build.PushFront(updateEndpointForS3Config)
// Require SSL when using SSE keys
c.Handlers.Validate.PushBack(validateSSERequiresSSL)
c.Handlers.Build.PushBack(computeSSEKeys)
// S3 uses custom error unmarshaling logic
c.Handlers.UnmarshalError.Clear()
c.Handlers.UnmarshalError.PushBack(unmarshalError)
}
func defaultInitRequestFn(r *request.Request) {
// Add reuest handlers for specific platforms.
// e.g. 100-continue support for PUT requests using Go 1.6
platformRequestHandlers(r)
switch r.Operation.Name {
case opPutBucketCors, opPutBucketLifecycle, opPutBucketPolicy,
opPutBucketTagging, opDeleteObjects, opPutBucketLifecycleConfiguration,
opPutBucketReplication:
// These S3 operations require Content-MD5 to be set
r.Handlers.Build.PushBack(contentMD5)
case opGetBucketLocation:
// GetBucketLocation has custom parsing logic
r.Handlers.Unmarshal.PushFront(buildGetBucketLocation)
case opCreateBucket:
// Auto-populate LocationConstraint with current region
r.Handlers.Validate.PushFront(populateLocationConstraint)
case opCopyObject, opUploadPartCopy, opCompleteMultipartUpload:
r.Handlers.Unmarshal.PushFront(copyMultipartStatusOKUnmarhsalError)
case opPutObject, opUploadPart:
r.Handlers.Build.PushBack(computeBodyHashes)
// Disabled until #1837 root issue is resolved.
// case opGetObject:
// r.Handlers.Build.PushBack(askForTxEncodingAppendMD5)
// r.Handlers.Unmarshal.PushBack(useMD5ValidationReader)
}
}
// bucketGetter is an accessor interface to grab the "Bucket" field from
// an S3 type.
type bucketGetter interface {
getBucket() string
}
// sseCustomerKeyGetter is an accessor interface to grab the "SSECustomerKey"
// field from an S3 type.
type sseCustomerKeyGetter interface {
getSSECustomerKey() string
}
// copySourceSSECustomerKeyGetter is an accessor interface to grab the
// "CopySourceSSECustomerKey" field from an S3 type.
type copySourceSSECustomerKeyGetter interface {
getCopySourceSSECustomerKey() string
}

View File

@@ -0,0 +1,171 @@
package s3_test
import (
"crypto/md5"
"encoding/base64"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func assertMD5(t *testing.T, req *request.Request) {
err := req.Build()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
b, _ := ioutil.ReadAll(req.HTTPRequest.Body)
out := md5.Sum(b)
if len(b) == 0 {
t.Error("expected non-empty value")
}
if e, a := base64.StdEncoding.EncodeToString(out[:]), req.HTTPRequest.Header.Get("Content-MD5"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func TestMD5InPutBucketCors(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketCorsRequest(&s3.PutBucketCorsInput{
Bucket: aws.String("bucketname"),
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{
{
AllowedMethods: []*string{aws.String("GET")},
AllowedOrigins: []*string{aws.String("*")},
},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketLifecycle(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketLifecycleRequest(&s3.PutBucketLifecycleInput{
Bucket: aws.String("bucketname"),
LifecycleConfiguration: &s3.LifecycleConfiguration{
Rules: []*s3.Rule{
{
ID: aws.String("ID"),
Prefix: aws.String("Prefix"),
Status: aws.String("Enabled"),
},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketPolicy(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketPolicyRequest(&s3.PutBucketPolicyInput{
Bucket: aws.String("bucketname"),
Policy: aws.String("{}"),
})
assertMD5(t, req)
}
func TestMD5InPutBucketTagging(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketTaggingRequest(&s3.PutBucketTaggingInput{
Bucket: aws.String("bucketname"),
Tagging: &s3.Tagging{
TagSet: []*s3.Tag{
{Key: aws.String("KEY"), Value: aws.String("VALUE")},
},
},
})
assertMD5(t, req)
}
func TestMD5InDeleteObjects(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.DeleteObjectsRequest(&s3.DeleteObjectsInput{
Bucket: aws.String("bucketname"),
Delete: &s3.Delete{
Objects: []*s3.ObjectIdentifier{
{Key: aws.String("key")},
},
},
})
assertMD5(t, req)
}
func TestMD5InPutBucketLifecycleConfiguration(t *testing.T) {
svc := s3.New(unit.Session)
req, _ := svc.PutBucketLifecycleConfigurationRequest(&s3.PutBucketLifecycleConfigurationInput{
Bucket: aws.String("bucketname"),
LifecycleConfiguration: &s3.BucketLifecycleConfiguration{
Rules: []*s3.LifecycleRule{
{Prefix: aws.String("prefix"), Status: aws.String(s3.ExpirationStatusEnabled)},
},
},
})
assertMD5(t, req)
}
const (
metaKeyPrefix = `X-Amz-Meta-`
utf8KeySuffix = `My-Info`
utf8Value = "hello-世界\u0444"
)
func TestPutObjectMetadataWithUnicode(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if e, a := utf8Value, r.Header.Get(metaKeyPrefix+utf8KeySuffix); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}))
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
})
_, err := svc.PutObject(&s3.PutObjectInput{
Bucket: aws.String("my_bucket"),
Key: aws.String("my_key"),
Body: strings.NewReader(""),
Metadata: func() map[string]*string {
v := map[string]*string{}
v[utf8KeySuffix] = aws.String(utf8Value)
return v
}(),
})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
}
func TestGetObjectMetadataWithUnicode(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(metaKeyPrefix+utf8KeySuffix, utf8Value)
}))
svc := s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
})
resp, err := svc.GetObject(&s3.GetObjectInput{
Bucket: aws.String("my_bucket"),
Key: aws.String("my_key"),
})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
resp.Body.Close()
if e, a := utf8Value, *resp.Metadata[utf8KeySuffix]; e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}

26
vendor/github.com/aws/aws-sdk-go/service/s3/doc.go generated vendored Normal file
View File

@@ -0,0 +1,26 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
// Package s3 provides the client and types for making API
// requests to Amazon Simple Storage Service.
//
// See https://docs.aws.amazon.com/goto/WebAPI/s3-2006-03-01 for more information on this service.
//
// See s3 package documentation for more information.
// https://docs.aws.amazon.com/sdk-for-go/api/service/s3/
//
// Using the Client
//
// To contact Amazon Simple Storage Service with the SDK use the New function to create
// a new service client. With that client you can make API requests to the service.
// These clients are safe to use concurrently.
//
// See the SDK's documentation for more information on how to use the SDK.
// https://docs.aws.amazon.com/sdk-for-go/api/
//
// See aws.Config documentation for more information on configuring SDK clients.
// https://docs.aws.amazon.com/sdk-for-go/api/aws/#Config
//
// See the Amazon Simple Storage Service client S3 for more
// information on creating client for this service.
// https://docs.aws.amazon.com/sdk-for-go/api/service/s3/#New
package s3

View File

@@ -0,0 +1,109 @@
// Upload Managers
//
// The s3manager package's Uploader provides concurrent upload of content to S3
// by taking advantage of S3's Multipart APIs. The Uploader also supports both
// io.Reader for streaming uploads, and will also take advantage of io.ReadSeeker
// for optimizations if the Body satisfies that type. Once the Uploader instance
// is created you can call Upload concurrently from multiple goroutines safely.
//
// // The session the S3 Uploader will use
// sess := session.Must(session.NewSession())
//
// // Create an uploader with the session and default options
// uploader := s3manager.NewUploader(sess)
//
// f, err := os.Open(filename)
// if err != nil {
// return fmt.Errorf("failed to open file %q, %v", filename, err)
// }
//
// // Upload the file to S3.
// result, err := uploader.Upload(&s3manager.UploadInput{
// Bucket: aws.String(myBucket),
// Key: aws.String(myString),
// Body: f,
// })
// if err != nil {
// return fmt.Errorf("failed to upload file, %v", err)
// }
// fmt.Printf("file uploaded to, %s\n", aws.StringValue(result.Location))
//
// See the s3manager package's Uploader type documentation for more information.
// https://docs.aws.amazon.com/sdk-for-go/api/service/s3/s3manager/#Uploader
//
// Download Manager
//
// The s3manager package's Downloader provides concurrently downloading of Objects
// from S3. The Downloader will write S3 Object content with an io.WriterAt.
// Once the Downloader instance is created you can call Download concurrently from
// multiple goroutines safely.
//
// // The session the S3 Downloader will use
// sess := session.Must(session.NewSession())
//
// // Create a downloader with the session and default options
// downloader := s3manager.NewDownloader(sess)
//
// // Create a file to write the S3 Object contents to.
// f, err := os.Create(filename)
// if err != nil {
// return fmt.Errorf("failed to create file %q, %v", filename, err)
// }
//
// // Write the contents of S3 Object to the file
// n, err := downloader.Download(f, &s3.GetObjectInput{
// Bucket: aws.String(myBucket),
// Key: aws.String(myString),
// })
// if err != nil {
// return fmt.Errorf("failed to download file, %v", err)
// }
// fmt.Printf("file downloaded, %d bytes\n", n)
//
// See the s3manager package's Downloader type documentation for more information.
// https://docs.aws.amazon.com/sdk-for-go/api/service/s3/s3manager/#Downloader
//
// Get Bucket Region
//
// GetBucketRegion will attempt to get the region for a bucket using a region
// hint to determine which AWS partition to perform the query on. Use this utility
// to determine the region a bucket is in.
//
// sess := session.Must(session.NewSession())
//
// bucket := "my-bucket"
// region, err := s3manager.GetBucketRegion(ctx, sess, bucket, "us-west-2")
// if err != nil {
// if aerr, ok := err.(awserr.Error); ok && aerr.Code() == "NotFound" {
// fmt.Fprintf(os.Stderr, "unable to find bucket %s's region not found\n", bucket)
// }
// return err
// }
// fmt.Printf("Bucket %s is in %s region\n", bucket, region)
//
// See the s3manager package's GetBucketRegion function documentation for more information
// https://docs.aws.amazon.com/sdk-for-go/api/service/s3/s3manager/#GetBucketRegion
//
// S3 Crypto Client
//
// The s3crypto package provides the tools to upload and download encrypted
// content from S3. The Encryption and Decryption clients can be used concurrently
// once the client is created.
//
// sess := session.Must(session.NewSession())
//
// // Create the decryption client.
// svc := s3crypto.NewDecryptionClient(sess)
//
// // The object will be downloaded from S3 and decrypted locally. By metadata
// // about the object's encryption will instruct the decryption client how
// // decrypt the content of the object. By default KMS is used for keys.
// result, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String(myBucket),
// Key: aws.String(myKey),
// })
//
// See the s3crypto package documentation for more information.
// https://docs.aws.amazon.com/sdk-for-go/api/service/s3/s3crypto/
//
package s3

48
vendor/github.com/aws/aws-sdk-go/service/s3/errors.go generated vendored Normal file
View File

@@ -0,0 +1,48 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
package s3
const (
// ErrCodeBucketAlreadyExists for service response error code
// "BucketAlreadyExists".
//
// The requested bucket name is not available. The bucket namespace is shared
// by all users of the system. Please select a different name and try again.
ErrCodeBucketAlreadyExists = "BucketAlreadyExists"
// ErrCodeBucketAlreadyOwnedByYou for service response error code
// "BucketAlreadyOwnedByYou".
ErrCodeBucketAlreadyOwnedByYou = "BucketAlreadyOwnedByYou"
// ErrCodeNoSuchBucket for service response error code
// "NoSuchBucket".
//
// The specified bucket does not exist.
ErrCodeNoSuchBucket = "NoSuchBucket"
// ErrCodeNoSuchKey for service response error code
// "NoSuchKey".
//
// The specified key does not exist.
ErrCodeNoSuchKey = "NoSuchKey"
// ErrCodeNoSuchUpload for service response error code
// "NoSuchUpload".
//
// The specified multipart upload does not exist.
ErrCodeNoSuchUpload = "NoSuchUpload"
// ErrCodeObjectAlreadyInActiveTierError for service response error code
// "ObjectAlreadyInActiveTierError".
//
// This operation is not allowed against this storage tier
ErrCodeObjectAlreadyInActiveTierError = "ObjectAlreadyInActiveTierError"
// ErrCodeObjectNotInActiveTierError for service response error code
// "ObjectNotInActiveTierError".
//
// The source object of the COPY operation is not in the active tier and is
// only stored in Amazon Glacier.
ErrCodeObjectNotInActiveTierError = "ObjectNotInActiveTierError"
)

View File

@@ -0,0 +1,72 @@
package s3
import (
"encoding/csv"
"fmt"
"io"
"os"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
)
func ExampleS3_SelectObjectContent() {
sess := session.Must(session.NewSession())
svc := New(sess)
/*
Example myObjectKey CSV content:
name,number
gopher,0
ᵷodɥǝɹ,1
*/
// Make the Select Object Content API request using the object uploaded.
resp, err := svc.SelectObjectContent(&SelectObjectContentInput{
Bucket: aws.String("myBucket"),
Key: aws.String("myObjectKey"),
Expression: aws.String("SELECT name FROM S3Object WHERE cast(number as int) < 1"),
ExpressionType: aws.String(ExpressionTypeSql),
InputSerialization: &InputSerialization{
CSV: &CSVInput{
FileHeaderInfo: aws.String(FileHeaderInfoUse),
},
},
OutputSerialization: &OutputSerialization{
CSV: &CSVOutput{},
},
})
if err != nil {
fmt.Fprintf(os.Stderr, "failed making API request, %v\n", err)
return
}
defer resp.EventStream.Close()
results, resultWriter := io.Pipe()
go func() {
defer resultWriter.Close()
for event := range resp.EventStream.Events() {
switch e := event.(type) {
case *RecordsEvent:
resultWriter.Write(e.Payload)
case *StatsEvent:
fmt.Printf("Processed %d bytes\n", *e.Details.BytesProcessed)
}
}
}()
// Printout the results
resReader := csv.NewReader(results)
for {
record, err := resReader.Read()
if err == io.EOF {
break
}
fmt.Println(record)
}
if err := resp.EventStream.Err(); err != nil {
fmt.Fprintf(os.Stderr, "reading from event stream failed, %v\n", err)
}
}

View File

@@ -0,0 +1,239 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
// +build go1.6
package s3
import (
"bytes"
"io/ioutil"
"net/http"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/private/protocol/eventstream"
"github.com/aws/aws-sdk-go/private/protocol/eventstream/eventstreamapi"
"github.com/aws/aws-sdk-go/private/protocol/eventstream/eventstreamtest"
"github.com/aws/aws-sdk-go/private/protocol/restxml"
)
var _ time.Time
var _ awserr.Error
func TestSelectObjectContent_Read(t *testing.T) {
expectEvents, eventMsgs := mockSelectObjectContentReadEvents()
sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
eventstreamtest.ServeEventStream{
T: t,
Events: eventMsgs,
},
true,
)
if err != nil {
t.Fatalf("expect no error, %v", err)
}
defer cleanupFn()
svc := New(sess)
resp, err := svc.SelectObjectContent(nil)
if err != nil {
t.Fatalf("expect no error got, %v", err)
}
defer resp.EventStream.Close()
var i int
for event := range resp.EventStream.Events() {
if event == nil {
t.Errorf("%d, expect event, got nil", i)
}
if e, a := expectEvents[i], event; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %T %v, got %T %v", i, e, e, a, a)
}
i++
}
if err := resp.EventStream.Err(); err != nil {
t.Errorf("expect no error, %v", err)
}
}
func TestSelectObjectContent_ReadClose(t *testing.T) {
_, eventMsgs := mockSelectObjectContentReadEvents()
sess, cleanupFn, err := eventstreamtest.SetupEventStreamSession(t,
eventstreamtest.ServeEventStream{
T: t,
Events: eventMsgs,
},
true,
)
if err != nil {
t.Fatalf("expect no error, %v", err)
}
defer cleanupFn()
svc := New(sess)
resp, err := svc.SelectObjectContent(nil)
if err != nil {
t.Fatalf("expect no error got, %v", err)
}
resp.EventStream.Close()
<-resp.EventStream.Events()
if err := resp.EventStream.Err(); err != nil {
t.Errorf("expect no error, %v", err)
}
}
func BenchmarkSelectObjectContent_Read(b *testing.B) {
_, eventMsgs := mockSelectObjectContentReadEvents()
var buf bytes.Buffer
encoder := eventstream.NewEncoder(&buf)
for _, msg := range eventMsgs {
if err := encoder.Encode(msg); err != nil {
b.Fatalf("failed to encode message, %v", err)
}
}
stream := &loopReader{source: bytes.NewReader(buf.Bytes())}
sess := unit.Session
svc := New(sess, &aws.Config{
Endpoint: aws.String("https://example.com"),
DisableParamValidation: aws.Bool(true),
})
svc.Handlers.Send.Swap(corehandlers.SendHandler.Name,
request.NamedHandler{Name: "mockSend",
Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
Status: "200 OK",
StatusCode: 200,
Header: http.Header{},
Body: ioutil.NopCloser(stream),
}
},
},
)
resp, err := svc.SelectObjectContent(nil)
if err != nil {
b.Fatalf("failed to create request, %v", err)
}
defer resp.EventStream.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err = resp.EventStream.Err(); err != nil {
b.Fatalf("expect no error, got %v", err)
}
event := <-resp.EventStream.Events()
if event == nil {
b.Fatalf("expect event, got nil, %v, %d", resp.EventStream.Err(), i)
}
}
}
func mockSelectObjectContentReadEvents() (
[]SelectObjectContentEventStreamEvent,
[]eventstream.Message,
) {
expectEvents := []SelectObjectContentEventStreamEvent{
&ContinuationEvent{},
&EndEvent{},
&ProgressEvent{
Details: &Progress{
BytesProcessed: aws.Int64(1234),
BytesReturned: aws.Int64(1234),
BytesScanned: aws.Int64(1234),
},
},
&RecordsEvent{
Payload: []byte("blob value goes here"),
},
&StatsEvent{
Details: &Stats{
BytesProcessed: aws.Int64(1234),
BytesReturned: aws.Int64(1234),
BytesScanned: aws.Int64(1234),
},
},
}
var marshalers request.HandlerList
marshalers.PushBackNamed(restxml.BuildHandler)
payloadMarshaler := protocol.HandlerPayloadMarshal{
Marshalers: marshalers,
}
_ = payloadMarshaler
eventMsgs := []eventstream.Message{
{
Headers: eventstream.Headers{
eventstreamtest.EventMessageTypeHeader,
{
Name: eventstreamapi.EventTypeHeader,
Value: eventstream.StringValue("Cont"),
},
},
},
{
Headers: eventstream.Headers{
eventstreamtest.EventMessageTypeHeader,
{
Name: eventstreamapi.EventTypeHeader,
Value: eventstream.StringValue("End"),
},
},
},
{
Headers: eventstream.Headers{
eventstreamtest.EventMessageTypeHeader,
{
Name: eventstreamapi.EventTypeHeader,
Value: eventstream.StringValue("Progress"),
},
},
Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, expectEvents[2]),
},
{
Headers: eventstream.Headers{
eventstreamtest.EventMessageTypeHeader,
{
Name: eventstreamapi.EventTypeHeader,
Value: eventstream.StringValue("Records"),
},
},
Payload: expectEvents[3].(*RecordsEvent).Payload,
},
{
Headers: eventstream.Headers{
eventstreamtest.EventMessageTypeHeader,
{
Name: eventstreamapi.EventTypeHeader,
Value: eventstream.StringValue("Stats"),
},
},
Payload: eventstreamtest.MarshalEventPayload(payloadMarshaler, expectEvents[4]),
},
}
return expectEvents, eventMsgs
}
type loopReader struct {
source *bytes.Reader
}
func (c *loopReader) Read(p []byte) (int, error) {
if c.source.Len() == 0 {
c.source.Seek(0, 0)
}
return c.source.Read(p)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,155 @@
package s3
import (
"fmt"
"net/url"
"regexp"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
// an operationBlacklist is a list of operation names that should a
// request handler should not be executed with.
type operationBlacklist []string
// Continue will return true of the Request's operation name is not
// in the blacklist. False otherwise.
func (b operationBlacklist) Continue(r *request.Request) bool {
for i := 0; i < len(b); i++ {
if b[i] == r.Operation.Name {
return false
}
}
return true
}
var accelerateOpBlacklist = operationBlacklist{
opListBuckets, opCreateBucket, opDeleteBucket,
}
// Request handler to automatically add the bucket name to the endpoint domain
// if possible. This style of bucket is valid for all bucket names which are
// DNS compatible and do not contain "."
func updateEndpointForS3Config(r *request.Request) {
forceHostStyle := aws.BoolValue(r.Config.S3ForcePathStyle)
accelerate := aws.BoolValue(r.Config.S3UseAccelerate)
if accelerate && accelerateOpBlacklist.Continue(r) {
if forceHostStyle {
if r.Config.Logger != nil {
r.Config.Logger.Log("ERROR: aws.Config.S3UseAccelerate is not compatible with aws.Config.S3ForcePathStyle, ignoring S3ForcePathStyle.")
}
}
updateEndpointForAccelerate(r)
} else if !forceHostStyle && r.Operation.Name != opGetBucketLocation {
updateEndpointForHostStyle(r)
}
}
func updateEndpointForHostStyle(r *request.Request) {
bucket, ok := bucketNameFromReqParams(r.Params)
if !ok {
// Ignore operation requests if the bucketname was not provided
// if this is an input validation error the validation handler
// will report it.
return
}
if !hostCompatibleBucketName(r.HTTPRequest.URL, bucket) {
// bucket name must be valid to put into the host
return
}
moveBucketToHost(r.HTTPRequest.URL, bucket)
}
var (
accelElem = []byte("s3-accelerate.dualstack.")
)
func updateEndpointForAccelerate(r *request.Request) {
bucket, ok := bucketNameFromReqParams(r.Params)
if !ok {
// Ignore operation requests if the bucketname was not provided
// if this is an input validation error the validation handler
// will report it.
return
}
if !hostCompatibleBucketName(r.HTTPRequest.URL, bucket) {
r.Error = awserr.New("InvalidParameterException",
fmt.Sprintf("bucket name %s is not compatible with S3 Accelerate", bucket),
nil)
return
}
parts := strings.Split(r.HTTPRequest.URL.Host, ".")
if len(parts) < 3 {
r.Error = awserr.New("InvalidParameterExecption",
fmt.Sprintf("unable to update endpoint host for S3 accelerate, hostname invalid, %s",
r.HTTPRequest.URL.Host), nil)
return
}
if parts[0] == "s3" || strings.HasPrefix(parts[0], "s3-") {
parts[0] = "s3-accelerate"
}
for i := 1; i+1 < len(parts); i++ {
if parts[i] == aws.StringValue(r.Config.Region) {
parts = append(parts[:i], parts[i+1:]...)
break
}
}
r.HTTPRequest.URL.Host = strings.Join(parts, ".")
moveBucketToHost(r.HTTPRequest.URL, bucket)
}
// Attempts to retrieve the bucket name from the request input parameters.
// If no bucket is found, or the field is empty "", false will be returned.
func bucketNameFromReqParams(params interface{}) (string, bool) {
if iface, ok := params.(bucketGetter); ok {
b := iface.getBucket()
return b, len(b) > 0
}
return "", false
}
// hostCompatibleBucketName returns true if the request should
// put the bucket in the host. This is false if S3ForcePathStyle is
// explicitly set or if the bucket is not DNS compatible.
func hostCompatibleBucketName(u *url.URL, bucket string) bool {
// Bucket might be DNS compatible but dots in the hostname will fail
// certificate validation, so do not use host-style.
if u.Scheme == "https" && strings.Contains(bucket, ".") {
return false
}
// if the bucket is DNS compatible
return dnsCompatibleBucketName(bucket)
}
var reDomain = regexp.MustCompile(`^[a-z0-9][a-z0-9\.\-]{1,61}[a-z0-9]$`)
var reIPAddress = regexp.MustCompile(`^(\d+\.){3}\d+$`)
// dnsCompatibleBucketName returns true if the bucket name is DNS compatible.
// Buckets created outside of the classic region MUST be DNS compatible.
func dnsCompatibleBucketName(bucket string) bool {
return reDomain.MatchString(bucket) &&
!reIPAddress.MatchString(bucket) &&
!strings.Contains(bucket, "..")
}
// moveBucketToHost moves the bucket name from the URI path to URL host.
func moveBucketToHost(u *url.URL, bucket string) {
u.Host = bucket + "." + u.Host
u.Path = strings.Replace(u.Path, "/{Bucket}", "", -1)
if u.Path == "" {
u.Path = "/"
}
}

View File

@@ -0,0 +1,179 @@
package s3_test
import (
"encoding/json"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
type s3BucketTest struct {
bucket string
url string
errCode string
}
var (
sslTests = []s3BucketTest{
{"abc", "https://abc.s3.mock-region.amazonaws.com/", ""},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c", ""},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c", ""},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc", ""},
}
nosslTests = []s3BucketTest{
{"a.b.c", "http://a.b.c.s3.mock-region.amazonaws.com/", ""},
{"a..bc", "http://s3.mock-region.amazonaws.com/a..bc", ""},
}
forcepathTests = []s3BucketTest{
{"abc", "https://s3.mock-region.amazonaws.com/abc", ""},
{"a$b$c", "https://s3.mock-region.amazonaws.com/a%24b%24c", ""},
{"a.b.c", "https://s3.mock-region.amazonaws.com/a.b.c", ""},
{"a..bc", "https://s3.mock-region.amazonaws.com/a..bc", ""},
}
accelerateTests = []s3BucketTest{
{"abc", "https://abc.s3-accelerate.amazonaws.com/", ""},
{"a.b.c", "https://s3.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
{"a$b$c", "https://s3.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
}
accelerateNoSSLTests = []s3BucketTest{
{"abc", "http://abc.s3-accelerate.amazonaws.com/", ""},
{"a.b.c", "http://a.b.c.s3-accelerate.amazonaws.com/", ""},
{"a$b$c", "http://s3.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
}
accelerateDualstack = []s3BucketTest{
{"abc", "https://abc.s3-accelerate.dualstack.amazonaws.com/", ""},
{"a.b.c", "https://s3.dualstack.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
{"a$b$c", "https://s3.dualstack.mock-region.amazonaws.com/%7BBucket%7D", "InvalidParameterException"},
}
)
func runTests(t *testing.T, svc *s3.S3, tests []s3BucketTest) {
for i, test := range tests {
req, _ := svc.ListObjectsRequest(&s3.ListObjectsInput{Bucket: &test.bucket})
req.Build()
if e, a := test.url, req.HTTPRequest.URL.String(); e != a {
t.Errorf("%d, expect url %s, got %s", i, e, a)
}
if test.errCode != "" {
if err := req.Error; err == nil {
t.Fatalf("%d, expect no error", i)
}
if a, e := req.Error.(awserr.Error).Code(), test.errCode; !strings.Contains(a, e) {
t.Errorf("%d, expect error code to contain %q, got %q", i, e, a)
}
}
}
}
func TestAccelerateBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{S3UseAccelerate: aws.Bool(true)})
runTests(t, s, accelerateTests)
}
func TestAccelerateNoSSLBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{S3UseAccelerate: aws.Bool(true), DisableSSL: aws.Bool(true)})
runTests(t, s, accelerateNoSSLTests)
}
func TestAccelerateDualstackBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{
S3UseAccelerate: aws.Bool(true),
UseDualStack: aws.Bool(true),
})
runTests(t, s, accelerateDualstack)
}
func TestHostStyleBucketBuild(t *testing.T) {
s := s3.New(unit.Session)
runTests(t, s, sslTests)
}
func TestHostStyleBucketBuildNoSSL(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{DisableSSL: aws.Bool(true)})
runTests(t, s, nosslTests)
}
func TestPathStyleBucketBuild(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{S3ForcePathStyle: aws.Bool(true)})
runTests(t, s, forcepathTests)
}
func TestHostStyleBucketGetBucketLocation(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.GetBucketLocationRequest(&s3.GetBucketLocationInput{
Bucket: aws.String("bucket"),
})
req.Build()
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
u, _ := url.Parse(req.HTTPRequest.URL.String())
if e, a := "bucket", u.Host; strings.Contains(a, e) {
t.Errorf("expect %s to not be in %s", e, a)
}
if e, a := "bucket", u.Path; !strings.Contains(a, e) {
t.Errorf("expect %s to be in %s", e, a)
}
}
func TestVirtualHostStyleSuite(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "virtual_host.json"))
if err != nil {
t.Fatalf("expect no error, %v", err)
}
cases := []struct {
Bucket string
Region string
UseDualStack bool
UseS3Accelerate bool
ConfiguredAddressingStyle string
ExpectedURI string
}{}
decoder := json.NewDecoder(f)
if err := decoder.Decode(&cases); err != nil {
t.Fatalf("expect no error, %v", err)
}
const testPathStyle = "path"
for i, c := range cases {
svc := s3.New(unit.Session, &aws.Config{
Region: &c.Region,
UseDualStack: &c.UseDualStack,
S3UseAccelerate: &c.UseS3Accelerate,
S3ForcePathStyle: aws.Bool(c.ConfiguredAddressingStyle == testPathStyle),
})
req, _ := svc.HeadBucketRequest(&s3.HeadBucketInput{
Bucket: &c.Bucket,
})
req.Build()
if req.Error != nil {
t.Fatalf("expect no error, got %v", req.Error)
}
// Trim trailing '/' that are added by the SDK but not in the tests.
actualURI := strings.TrimRightFunc(
req.HTTPRequest.URL.String(),
func(r rune) bool { return r == '/' },
)
if e, a := c.ExpectedURI, actualURI; e != a {
t.Errorf("%d, expect\n%s\nurl to be\n%s", i, e, a)
}
}
}

View File

@@ -0,0 +1,8 @@
// +build !go1.6
package s3
import "github.com/aws/aws-sdk-go/aws/request"
func platformRequestHandlers(r *request.Request) {
}

View File

@@ -0,0 +1,28 @@
// +build go1.6
package s3
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
func platformRequestHandlers(r *request.Request) {
if r.Operation.HTTPMethod == "PUT" {
// 100-Continue should only be used on put requests.
r.Handlers.Sign.PushBack(add100Continue)
}
}
func add100Continue(r *request.Request) {
if aws.BoolValue(r.Config.S3Disable100Continue) {
return
}
if r.HTTPRequest.ContentLength < 1024*1024*2 {
// Ignore requests smaller than 2MB. This helps prevent delaying
// requests unnecessarily.
return
}
r.HTTPRequest.Header.Set("Expect", "100-Continue")
}

View File

@@ -0,0 +1,83 @@
// +build go1.6
package s3_test
import (
"bytes"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestAdd100Continue_Added(t *testing.T) {
svc := s3.New(unit.Session)
r, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
Body: bytes.NewReader(make([]byte, 1024*1024*5)),
})
err := r.Sign()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := "100-Continue", r.HTTPRequest.Header.Get("Expect"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func TestAdd100Continue_SkipDisabled(t *testing.T) {
svc := s3.New(unit.Session, aws.NewConfig().WithS3Disable100Continue(true))
r, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
Body: bytes.NewReader(make([]byte, 1024*1024*5)),
})
err := r.Sign()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if r.HTTPRequest.Header.Get("Expect") != "" {
t.Errorf("expected empty value, but received %s", r.HTTPRequest.Header.Get("Expect"))
}
}
func TestAdd100Continue_SkipNonPUT(t *testing.T) {
svc := s3.New(unit.Session)
r, _ := svc.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
})
err := r.Sign()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if r.HTTPRequest.Header.Get("Expect") != "" {
t.Errorf("expected empty value, but received %s", r.HTTPRequest.Header.Get("Expect"))
}
}
func TestAdd100Continue_SkipTooSmall(t *testing.T) {
svc := s3.New(unit.Session)
r, _ := svc.PutObjectRequest(&s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("dest"),
Body: bytes.NewReader(make([]byte, 1024*1024*1)),
})
err := r.Sign()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if r.HTTPRequest.Header.Get("Expect") != "" {
t.Errorf("expected empty value, but received %s", r.HTTPRequest.Header.Get("Expect"))
}
}

View File

@@ -0,0 +1,190 @@
package s3crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"io"
)
// AESCBC is a symmetric crypto algorithm. This algorithm
// requires a padder due to CBC needing to be of the same block
// size. AES CBC is vulnerable to Padding Oracle attacks and
// so should be avoided when possible.
type aesCBC struct {
encrypter cipher.BlockMode
decrypter cipher.BlockMode
padder Padder
}
// newAESCBC creates a new AES CBC cipher. Expects keys to be of
// the correct size.
func newAESCBC(cd CipherData, padder Padder) (Cipher, error) {
block, err := aes.NewCipher(cd.Key)
if err != nil {
return nil, err
}
encrypter := cipher.NewCBCEncrypter(block, cd.IV)
decrypter := cipher.NewCBCDecrypter(block, cd.IV)
return &aesCBC{encrypter, decrypter, padder}, nil
}
// Encrypt will encrypt the data using AES CBC by returning
// an io.Reader. The io.Reader will encrypt the data as Read
// is called.
func (c *aesCBC) Encrypt(src io.Reader) io.Reader {
reader := &cbcEncryptReader{
encrypter: c.encrypter,
src: src,
padder: c.padder,
}
return reader
}
type cbcEncryptReader struct {
encrypter cipher.BlockMode
src io.Reader
padder Padder
size int
buf bytes.Buffer
}
// Read will read from our io.Reader and encrypt the data as necessary.
// Due to padding, we have to do some logic that when we encounter an
// end of file to pad properly.
func (reader *cbcEncryptReader) Read(data []byte) (int, error) {
n, err := reader.src.Read(data)
reader.size += n
blockSize := reader.encrypter.BlockSize()
reader.buf.Write(data[:n])
if err == io.EOF {
b := make([]byte, getSliceSize(blockSize, reader.buf.Len(), len(data)))
n, err = reader.buf.Read(b)
if err != nil && err != io.EOF {
return n, err
}
// The buffer is now empty, we can now pad the data
if reader.buf.Len() == 0 {
b, err = reader.padder.Pad(b[:n], reader.size)
if err != nil {
return n, err
}
n = len(b)
err = io.EOF
}
// We only want to encrypt if we have read anything
if n > 0 {
reader.encrypter.CryptBlocks(data, b)
}
return n, err
}
if err != nil {
return n, err
}
if size := reader.buf.Len(); size >= blockSize {
nBlocks := size / blockSize
if size > len(data) {
nBlocks = len(data) / blockSize
}
if nBlocks > 0 {
b := make([]byte, nBlocks*blockSize)
n, _ = reader.buf.Read(b)
reader.encrypter.CryptBlocks(data, b[:n])
}
} else {
n = 0
}
return n, nil
}
// Decrypt will decrypt the data using AES CBC
func (c *aesCBC) Decrypt(src io.Reader) io.Reader {
return &cbcDecryptReader{
decrypter: c.decrypter,
src: src,
padder: c.padder,
}
}
type cbcDecryptReader struct {
decrypter cipher.BlockMode
src io.Reader
padder Padder
buf bytes.Buffer
}
// Read will read from our io.Reader and decrypt the data as necessary.
// Due to padding, we have to do some logic that when we encounter an
// end of file to pad properly.
func (reader *cbcDecryptReader) Read(data []byte) (int, error) {
n, err := reader.src.Read(data)
blockSize := reader.decrypter.BlockSize()
reader.buf.Write(data[:n])
if err == io.EOF {
b := make([]byte, getSliceSize(blockSize, reader.buf.Len(), len(data)))
n, err = reader.buf.Read(b)
if err != nil && err != io.EOF {
return n, err
}
// We only want to decrypt if we have read anything
if n > 0 {
reader.decrypter.CryptBlocks(data, b)
}
if reader.buf.Len() == 0 {
b, err = reader.padder.Unpad(data[:n])
n = len(b)
if err != nil {
return n, err
}
err = io.EOF
}
return n, err
}
if err != nil {
return n, err
}
if size := reader.buf.Len(); size >= blockSize {
nBlocks := size / blockSize
if size > len(data) {
nBlocks = len(data) / blockSize
}
// The last block is always padded. This will allow us to unpad
// when we receive an io.EOF error
nBlocks -= blockSize
if nBlocks > 0 {
b := make([]byte, nBlocks*blockSize)
n, _ = reader.buf.Read(b)
reader.decrypter.CryptBlocks(data, b[:n])
} else {
n = 0
}
}
return n, nil
}
// getSliceSize will return the correct amount of bytes we need to
// read with regards to padding.
func getSliceSize(blockSize, bufSize, dataSize int) int {
size := bufSize
if bufSize > dataSize {
size = dataSize
}
size = size - (size % blockSize) - blockSize
if size <= 0 {
size = blockSize
}
return size
}

View File

@@ -0,0 +1,73 @@
package s3crypto
import (
"io"
"strings"
)
const (
cbcKeySize = 32
cbcNonceSize = 16
)
type cbcContentCipherBuilder struct {
generator CipherDataGenerator
padder Padder
}
// AESCBCContentCipherBuilder returns a new encryption only mode structure with a specific cipher
// for the master key
func AESCBCContentCipherBuilder(generator CipherDataGenerator, padder Padder) ContentCipherBuilder {
return cbcContentCipherBuilder{generator: generator, padder: padder}
}
func (builder cbcContentCipherBuilder) ContentCipher() (ContentCipher, error) {
cd, err := builder.generator.GenerateCipherData(cbcKeySize, cbcNonceSize)
if err != nil {
return nil, err
}
cd.Padder = builder.padder
return newAESCBCContentCipher(cd)
}
// newAESCBCContentCipher will create a new aes cbc content cipher. If the cipher data's
// will set the CEK algorithm if it hasn't been set.
func newAESCBCContentCipher(cd CipherData) (ContentCipher, error) {
if len(cd.CEKAlgorithm) == 0 {
cd.CEKAlgorithm = strings.Join([]string{AESCBC, cd.Padder.Name()}, "/")
}
cipher, err := newAESCBC(cd, cd.Padder)
if err != nil {
return nil, err
}
return &aesCBCContentCipher{
CipherData: cd,
Cipher: cipher,
}, nil
}
// aesCBCContentCipher will use AES CBC for the main cipher.
type aesCBCContentCipher struct {
CipherData CipherData
Cipher Cipher
}
// EncryptContents will generate a random key and iv and encrypt the data using cbc
func (cc *aesCBCContentCipher) EncryptContents(src io.Reader) (io.Reader, error) {
return cc.Cipher.Encrypt(src), nil
}
// DecryptContents will use the symmetric key provider to instantiate a new CBC cipher.
// We grab a decrypt reader from CBC and wrap it in a CryptoReadCloser. The only error
// expected here is when the key or iv is of invalid length.
func (cc *aesCBCContentCipher) DecryptContents(src io.ReadCloser) (io.ReadCloser, error) {
reader := cc.Cipher.Decrypt(src)
return &CryptoReadCloser{Body: src, Decrypter: reader}, nil
}
// GetCipherData returns cipher data
func (cc aesCBCContentCipher) GetCipherData() CipherData {
return cc.CipherData
}

View File

@@ -0,0 +1,20 @@
package s3crypto_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestAESCBCBuilder(t *testing.T) {
generator := mockGenerator{}
builder := s3crypto.AESCBCContentCipherBuilder(generator, s3crypto.NoPadder)
if builder == nil {
t.Fatal(builder)
}
_, err := builder.ContentCipher()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,29 @@
package s3crypto
const (
pkcs5BlockSize = 16
)
var aescbcPadding = aescbcPadder{pkcs7Padder{16}}
// AESCBCPadder is used to pad AES encrypted and decrypted data.
// Altough it uses the pkcs5Padder, it isn't following the RFC
// for PKCS5. The only reason why it is called pkcs5Padder is
// due to the Name returning PKCS5Padding.
var AESCBCPadder = Padder(aescbcPadding)
type aescbcPadder struct {
padder pkcs7Padder
}
func (padder aescbcPadder) Pad(b []byte, n int) ([]byte, error) {
return padder.padder.Pad(b, n)
}
func (padder aescbcPadder) Unpad(b []byte) ([]byte, error) {
return padder.padder.Unpad(b)
}
func (padder aescbcPadder) Name() string {
return "PKCS5Padding"
}

View File

@@ -0,0 +1,41 @@
package s3crypto
import (
"bytes"
"fmt"
"testing"
)
func TestAESCBCPadding(t *testing.T) {
for i := 0; i < 16; i++ {
input := make([]byte, i)
expected := append(input, bytes.Repeat([]byte{byte(16 - i)}, 16-i)...)
b, err := AESCBCPadder.Pad(input, len(input))
if err != nil {
t.Fatal("Expected error to be nil but received " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}
func TestAESCBCUnpadding(t *testing.T) {
for i := 0; i < 16; i++ {
expected := make([]byte, i)
input := append(expected, bytes.Repeat([]byte{byte(16 - i)}, 16-i)...)
b, err := AESCBCPadder.Unpad(input)
if err != nil {
t.Fatal("Error received, was expecting nil: " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}

View File

@@ -0,0 +1,501 @@
package s3crypto
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"testing"
)
func TestAESCBCEncryptDecrypt(t *testing.T) {
var testCases = []struct {
key string
iv string
plaintext string
ciphertext string
decodeHex bool
padder Padder
}{
// Test vectors from RFC 3602: https://tools.ietf.org/html/rfc3602
{
"06a9214036b8a15b512e03d534120006",
"3dafba429d9eb430b422da802c9fac41",
"Single block msg",
"e353779c1079aeb82708942dbe77181a",
false,
NoPadder,
},
{
"c286696d887c9aa0611bbb3e2025a45a",
"562e17996d093d28ddb3ba695a2e6f58",
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f",
"d296cd94c2cccf8a3a863028b5e1dc0a7586602d253cfff91b8266bea6d61ab1",
true,
NoPadder,
},
{
"6c3ea0477630ce21a2ce334aa746c2cd",
"c782dc4c098c66cbd9cd27d825682c81",
"This is a 48-byte message (exactly 3 AES blocks)",
"d0a02b3836451753d493665d33f0e8862dea54cdb293abc7506939276772f8d5021c19216bad525c8579695d83ba2684",
false,
NoPadder,
},
{
"56e47a38c5598974bc46903dba290349",
"8ce82eefbea0da3c44699ed7db51b7d9",
"a0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedf",
"c30e32ffedc0774e6aff6af0869f71aa0f3af07a9a31a9c684db207eb0ef8e4e35907aa632c3ffdf868bb7b29d3d46ad83ce9f9a102ee99d49a53e87f4c3da55",
true,
NoPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"",
"B012949BA07D1A6DCE9DEE67274D41AB",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41",
"8A11ABA68A566132FFE04DB336621D41",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141",
"97D0896E41DFDB5CEA4A9EB70A938CFD",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141",
"8464EAD45FA2D8790E8741E32C28083F",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141",
"1E656D6E2745BA9F154FAF136B2BC73D",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141",
"0B6031C4B230DAC6BD6D3F195645B287",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141",
"5D09FEB6462BB489489A7E18FD341D9D",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141",
"85745E398F2FD1050C2CE8F8614DA369",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141",
"7BE52933970BA7B0FC6FB3FC37648205",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141",
"ED3A1E134EF36CCFE60C8123B4272F89",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141",
"C3B7C9E177E1052FC736F65FC1E74209",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141",
"C3A8B53F7F57F0B9D346FA99810A3C28",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141",
"D16B1ECE5BF00AF919E139E99775FF06",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141",
"B258F4DF57FFCA1EFCF8D76140F05139",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141",
"3CD2282DE24A2CF9E23326CC3DC9077A",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141",
"3010232E7C752A3B4C9EE428B4C4FE88",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A22BC4E6D03BFD2418DD412D1ED1B31AF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A5427BBD4A4D50776989441370E3B5B16",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A7FF985F55567D1B25EA40E23BB4CB1FE",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A0835E548C7370D8F8D9925C0E6B54727",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ADC0CF1436399E67BC1122B31CB596649",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A3D096F0DEAFF91938B82E5D404B0B065",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AD56ABA897A355CF307CCB74226243192",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A151284F950B1B1DBCAD6D9E7900DF4E6",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AEF85A612514121C299A1D87116C4A182",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A67F157569BFB4013EA3AD16DB8C69AD6",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AF8520D191F6ACBD88B2140588B91C697",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ADD8BBAA71745669B96F2683E2F5AEC35",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217AFB2D4282817D7EC6B33EFAD7AA14A3C5",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A459B89E7E0DAF3DA654576B60B2DA7CE",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A65759F23F9789D05B23D5DBAA9E32036",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217A03C78FBD5E2CB08B3B6D181E23FBDE79",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA013D941FBBDE56C106C482CD022F290F",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA0645D313AC3C29B79DB1AA2E00A5B393",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA2ED0FD8048053BF22EBE501D82C4B3F1",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAC57D706C7866A01D6E913F98AE57EE54",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAB7FC1241FAFDFE45C4FF982D5DC1DAEF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA7063EA296922DE8BDFD3B29D786C5F91",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA3A4603475F4AFDBFADC6E7FA908188B1",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA3365C63C2AF2A6C8FB4D0E9ED3C6FDA3",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA78BCC1874C0B7EB52645FC8F03B9C9CF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA9B7A31397718EECB89B9E9CCCD729326",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAB15EA8A67E9E9FADB4249710277F3D4F",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA94641D6A076193C660632CEA3F9CB02C",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAB2170A08417BE77F0EAA9110F4790E12",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA4E30F1CD7B2256ABD57DC3DAB05376C9",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"41414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA9909B7B93D01BDAAC22D15AF34DF1EEF",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"4141414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CAD97F5D1206F00E5C7225CAD81CCD4027",
true,
AESCBCPadder,
},
{
"11111111111111111111111111111111",
"22222222222222222222222222222222",
"414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141414141",
"C3304FA46097CBBA59085416764A217ACEF79EE1163A2F52888F87A3979EB3CA570CBB001A0C87558906B60C884AB5F41DA97CEF2A9401BC6DD0D22A54DBAD6D",
true,
AESCBCPadder,
},
}
for i, testCase := range testCases {
key, _ := hex.DecodeString(testCase.key)
iv, _ := hex.DecodeString(testCase.iv)
cd := CipherData{
Key: key,
IV: iv,
}
cbc, err := newAESCBC(cd, testCase.padder)
if err != nil {
t.Fatal(fmt.Sprintf("Case %d: Expected no error for cipher creation, but received: %v", i, err.Error()))
}
plaintext := []byte(testCase.plaintext)
if testCase.decodeHex {
plaintext, _ = hex.DecodeString(testCase.plaintext)
}
cipherdata := cbc.Encrypt(bytes.NewReader(plaintext))
ciphertext := []byte{}
b := make([]byte, 19)
err = nil
n := 0
for err != io.EOF {
n, err = cipherdata.Read(b)
ciphertext = append(ciphertext, b[:n]...)
}
if err != io.EOF {
t.Fatal(fmt.Sprintf("Case %d: Expected no error during io reading, but received: %v", i, err.Error()))
}
expectedData, _ := hex.DecodeString(testCase.ciphertext)
if bytes.Compare(expectedData, ciphertext) != 0 {
t.Log("\n", ciphertext, "\n", expectedData)
t.Fatal(fmt.Sprintf("Case %d: AES CBC encryption fails. Data is not the same", i))
}
plaindata := cbc.Decrypt(bytes.NewReader(ciphertext))
plaintextDecrypted := []byte{}
err = nil
for err != io.EOF {
n, err = plaindata.Read(b)
plaintextDecrypted = append(plaintextDecrypted, b[:n]...)
}
if err != io.EOF {
t.Fatal(fmt.Sprintf("Case %d: Expected no error during io reading, but received: %v", i, err.Error()))
}
if bytes.Compare(plaintext, plaintextDecrypted) != 0 {
t.Log("\n", plaintext, "\n", plaintextDecrypted)
t.Fatal(fmt.Sprintf("Case %d: AES CBC decryption fails. Data is not the same", i))
}
}
}

View File

@@ -0,0 +1,105 @@
package s3crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"io"
"io/ioutil"
)
// AESGCM Symmetric encryption algorithm. Since Golang designed this
// with only TLS in mind. We have to load it all into memory meaning
// this isn't streamed.
type aesGCM struct {
aead cipher.AEAD
nonce []byte
}
// newAESGCM creates a new AES GCM cipher. Expects keys to be of
// the correct size.
//
// Example:
//
// cd := &s3crypto.CipherData{
// Key: key,
// "IV": iv,
// }
// cipher, err := s3crypto.newAESGCM(cd)
func newAESGCM(cd CipherData) (Cipher, error) {
block, err := aes.NewCipher(cd.Key)
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return &aesGCM{aesgcm, cd.IV}, nil
}
// Encrypt will encrypt the data using AES GCM
// Tag will be included as the last 16 bytes of the slice
func (c *aesGCM) Encrypt(src io.Reader) io.Reader {
reader := &gcmEncryptReader{
encrypter: c.aead,
nonce: c.nonce,
src: src,
}
return reader
}
type gcmEncryptReader struct {
encrypter cipher.AEAD
nonce []byte
src io.Reader
buf *bytes.Buffer
}
func (reader *gcmEncryptReader) Read(data []byte) (int, error) {
if reader.buf == nil {
b, err := ioutil.ReadAll(reader.src)
if err != nil {
return len(b), err
}
b = reader.encrypter.Seal(b[:0], reader.nonce, b, nil)
reader.buf = bytes.NewBuffer(b)
}
return reader.buf.Read(data)
}
// Decrypt will decrypt the data using AES GCM
func (c *aesGCM) Decrypt(src io.Reader) io.Reader {
return &gcmDecryptReader{
decrypter: c.aead,
nonce: c.nonce,
src: src,
}
}
type gcmDecryptReader struct {
decrypter cipher.AEAD
nonce []byte
src io.Reader
buf *bytes.Buffer
}
func (reader *gcmDecryptReader) Read(data []byte) (int, error) {
if reader.buf == nil {
b, err := ioutil.ReadAll(reader.src)
if err != nil {
return len(b), err
}
b, err = reader.decrypter.Open(b[:0], reader.nonce, b, nil)
if err != nil {
return len(b), err
}
reader.buf = bytes.NewBuffer(b)
}
return reader.buf.Read(data)
}

View File

@@ -0,0 +1,68 @@
package s3crypto
import (
"io"
)
const (
gcmKeySize = 32
gcmNonceSize = 12
)
type gcmContentCipherBuilder struct {
generator CipherDataGenerator
}
// AESGCMContentCipherBuilder returns a new encryption only mode structure with a specific cipher
// for the master key
func AESGCMContentCipherBuilder(generator CipherDataGenerator) ContentCipherBuilder {
return gcmContentCipherBuilder{generator}
}
func (builder gcmContentCipherBuilder) ContentCipher() (ContentCipher, error) {
cd, err := builder.generator.GenerateCipherData(gcmKeySize, gcmNonceSize)
if err != nil {
return nil, err
}
return newAESGCMContentCipher(cd)
}
func newAESGCMContentCipher(cd CipherData) (ContentCipher, error) {
cd.CEKAlgorithm = AESGCMNoPadding
cd.TagLength = "128"
cipher, err := newAESGCM(cd)
if err != nil {
return nil, err
}
return &aesGCMContentCipher{
CipherData: cd,
Cipher: cipher,
}, nil
}
// AESGCMContentCipher will use AES GCM for the main cipher.
type aesGCMContentCipher struct {
CipherData CipherData
Cipher Cipher
}
// EncryptContents will generate a random key and iv and encrypt the data using cbc
func (cc *aesGCMContentCipher) EncryptContents(src io.Reader) (io.Reader, error) {
return cc.Cipher.Encrypt(src), nil
}
// DecryptContents will use the symmetric key provider to instantiate a new GCM cipher.
// We grab a decrypt reader from gcm and wrap it in a CryptoReadCloser. The only error
// expected here is when the key or iv is of invalid length.
func (cc *aesGCMContentCipher) DecryptContents(src io.ReadCloser) (io.ReadCloser, error) {
reader := cc.Cipher.Decrypt(src)
return &CryptoReadCloser{Body: src, Decrypter: reader}, nil
}
// GetCipherData returns cipher data
func (cc aesGCMContentCipher) GetCipherData() CipherData {
return cc.CipherData
}

View File

@@ -0,0 +1,28 @@
package s3crypto_test
import (
"testing"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestAESGCMContentCipherBuilder(t *testing.T) {
generator := mockGenerator{}
if builder := s3crypto.AESGCMContentCipherBuilder(generator); builder == nil {
t.Error("expected non-nil value")
}
}
func TestAESGCMContentCipherNewEncryptor(t *testing.T) {
generator := mockGenerator{}
builder := s3crypto.AESGCMContentCipherBuilder(generator)
cipher, err := builder.ContentCipher()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if cipher == nil {
t.Errorf("expected non-nil vaue")
}
}

View File

@@ -0,0 +1,81 @@
package s3crypto
import (
"bytes"
"encoding/hex"
"io/ioutil"
"testing"
)
// AES GCM
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_128_Test_0(t *testing.T) {
iv, _ := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
plaintext, _ := hex.DecodeString("2db5168e932556f8089a0622981d017d")
expected, _ := hex.DecodeString("fa4362189661d163fcd6a56d8bf0405ad636ac1bbedd5cc3ee727dc2ab4a9489")
tag, _ := hex.DecodeString("d636ac1bbedd5cc3ee727dc2ab4a9489")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_104_Test_3(t *testing.T) {
iv, _ := hex.DecodeString("4742357c335913153ff0eb0f")
key, _ := hex.DecodeString("e5a0eb92cc2b064e1bc80891faf1fab5e9a17a9c3a984e25416720e30e6c2b21")
plaintext, _ := hex.DecodeString("8499893e16b0ba8b007d54665a")
expected, _ := hex.DecodeString("eb8e6175f1fe38eb1acf95fd5188a8b74bb74fda553e91020a23deed45")
tag, _ := hex.DecodeString("88a8b74bb74fda553e91020a23deed45")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_256_Test_6(t *testing.T) {
iv, _ := hex.DecodeString("a291484c3de8bec6b47f525f")
key, _ := hex.DecodeString("37f39137416bafde6f75022a7a527cc593b6000a83ff51ec04871a0ff5360e4e")
plaintext, _ := hex.DecodeString("fafd94cede8b5a0730394bec68a8e77dba288d6ccaa8e1563a81d6e7ccc7fc97")
expected, _ := hex.DecodeString("44dc868006b21d49284016565ffb3979cc4271d967628bf7cdaf86db888e92e501a2b578aa2f41ec6379a44a31cc019c")
tag, _ := hex.DecodeString("01a2b578aa2f41ec6379a44a31cc019c")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func TestAES_GCM_NIST_gcmEncryptExtIV256_PTLen_408_Test_8(t *testing.T) {
iv, _ := hex.DecodeString("92f258071d79af3e63672285")
key, _ := hex.DecodeString("595f259c55abe00ae07535ca5d9b09d6efb9f7e9abb64605c337acbd6b14fc7e")
plaintext, _ := hex.DecodeString("a6fee33eb110a2d769bbc52b0f36969c287874f665681477a25fc4c48015c541fbe2394133ba490a34ee2dd67b898177849a91")
expected, _ := hex.DecodeString("bbca4a9e09ae9690c0f6f8d405e53dccd666aa9c5fa13c8758bc30abe1ddd1bcce0d36a1eaaaaffef20cd3c5970b9673f8a65c26ccecb9976fd6ac9c2c0f372c52c821")
tag, _ := hex.DecodeString("26ccecb9976fd6ac9c2c0f372c52c821")
aesgcmTest(t, iv, key, plaintext, expected, tag)
}
func aesgcmTest(t *testing.T, iv, key, plaintext, expected, tag []byte) {
cd := CipherData{
Key: key,
IV: iv,
}
gcm, err := newAESGCM(cd)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
cipherdata := gcm.Encrypt(bytes.NewReader(plaintext))
ciphertext, err := ioutil.ReadAll(cipherdata)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
// splitting tag and ciphertext
etag := ciphertext[len(ciphertext)-16:]
if !bytes.Equal(etag, tag) {
t.Errorf("expected tags to be equivalent")
}
if !bytes.Equal(ciphertext, expected) {
t.Errorf("expected ciphertext to be equivalent")
}
data := gcm.Decrypt(bytes.NewReader(ciphertext))
text, err := ioutil.ReadAll(data)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !bytes.Equal(plaintext, text) {
t.Errorf("expected ciphertext to be equivalent")
}
}

View File

@@ -0,0 +1,43 @@
package s3crypto
import (
"io"
)
// Cipher interface allows for either encryption and decryption of an object
type Cipher interface {
Encrypter
Decrypter
}
// Encrypter interface with only the encrypt method
type Encrypter interface {
Encrypt(io.Reader) io.Reader
}
// Decrypter interface with only the decrypt method
type Decrypter interface {
Decrypt(io.Reader) io.Reader
}
// CryptoReadCloser handles closing of the body and allowing reads from the decrypted
// content.
type CryptoReadCloser struct {
Body io.ReadCloser
Decrypter io.Reader
isClosed bool
}
// Close lets the CryptoReadCloser satisfy io.ReadCloser interface
func (rc *CryptoReadCloser) Close() error {
rc.isClosed = true
return rc.Body.Close()
}
// Read lets the CryptoReadCloser satisfy io.ReadCloser interface
func (rc *CryptoReadCloser) Read(b []byte) (int, error) {
if rc.isClosed {
return 0, io.EOF
}
return rc.Decrypter.Read(b)
}

View File

@@ -0,0 +1,31 @@
package s3crypto
import "io"
// ContentCipherBuilder is a builder interface that builds
// ciphers for each request.
type ContentCipherBuilder interface {
ContentCipher() (ContentCipher, error)
}
// ContentCipher deals with encrypting and decrypting content
type ContentCipher interface {
EncryptContents(io.Reader) (io.Reader, error)
DecryptContents(io.ReadCloser) (io.ReadCloser, error)
GetCipherData() CipherData
}
// CipherData is used for content encryption. It is used for storing the
// metadata of the encrypted content.
type CipherData struct {
Key []byte
IV []byte
WrapAlgorithm string
CEKAlgorithm string
TagLength string
MaterialDescription MaterialDescription
// EncryptedKey should be populated when calling GenerateCipherData
EncryptedKey []byte
Padder Padder
}

View File

@@ -0,0 +1,40 @@
package s3crypto_test
import (
"io/ioutil"
"strings"
"testing"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestCryptoReadCloserRead(t *testing.T) {
expectedStr := "HELLO WORLD "
str := strings.NewReader(expectedStr)
rc := &s3crypto.CryptoReadCloser{Body: ioutil.NopCloser(str), Decrypter: str}
b, err := ioutil.ReadAll(rc)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if expectedStr != string(b) {
t.Errorf("expected %s, but received %s", expectedStr, string(b))
}
}
func TestCryptoReadCloserClose(t *testing.T) {
data := "HELLO WORLD "
expectedStr := ""
str := strings.NewReader(data)
rc := &s3crypto.CryptoReadCloser{Body: ioutil.NopCloser(str), Decrypter: str}
rc.Close()
b, err := ioutil.ReadAll(rc)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if expectedStr != string(b) {
t.Errorf("expected %s, but received %s", expectedStr, string(b))
}
}

View File

@@ -0,0 +1,111 @@
package s3crypto
import (
"encoding/base64"
"strconv"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
)
func (client *DecryptionClient) contentCipherFromEnvelope(env Envelope) (ContentCipher, error) {
wrap, err := client.wrapFromEnvelope(env)
if err != nil {
return nil, err
}
return client.cekFromEnvelope(env, wrap)
}
func (client *DecryptionClient) wrapFromEnvelope(env Envelope) (CipherDataDecrypter, error) {
f, ok := client.WrapRegistry[env.WrapAlg]
if !ok || f == nil {
return nil, awserr.New(
"InvalidWrapAlgorithmError",
"wrap algorithm isn't supported, "+env.WrapAlg,
nil,
)
}
return f(env)
}
// AESGCMNoPadding is the constant value that is used to specify
// the CEK algorithm consiting of AES GCM with no padding.
const AESGCMNoPadding = "AES/GCM/NoPadding"
// AESCBC is the string constant that signifies the AES CBC algorithm cipher.
const AESCBC = "AES/CBC"
func (client *DecryptionClient) cekFromEnvelope(env Envelope, decrypter CipherDataDecrypter) (ContentCipher, error) {
f, ok := client.CEKRegistry[env.CEKAlg]
if !ok || f == nil {
return nil, awserr.New(
"InvalidCEKAlgorithmError",
"cek algorithm isn't supported, "+env.CEKAlg,
nil,
)
}
key, err := base64.StdEncoding.DecodeString(env.CipherKey)
if err != nil {
return nil, err
}
iv, err := base64.StdEncoding.DecodeString(env.IV)
if err != nil {
return nil, err
}
key, err = decrypter.DecryptKey(key)
if err != nil {
return nil, err
}
cd := CipherData{
Key: key,
IV: iv,
CEKAlgorithm: env.CEKAlg,
Padder: client.getPadder(env.CEKAlg),
}
return f(cd)
}
// getPadder will return an unpadder with checking the cek algorithm specific padder.
// If there wasn't a cek algorithm specific padder, we check the padder itself.
// We return a no unpadder, if no unpadder was found. This means any customization
// either contained padding within the cipher implementation, and to maintain
// backwards compatility we will simply not unpad anything.
func (client *DecryptionClient) getPadder(cekAlg string) Padder {
padder, ok := client.PadderRegistry[cekAlg]
if !ok {
padder, ok = client.PadderRegistry[cekAlg[strings.LastIndex(cekAlg, "/")+1:]]
if !ok {
return NoPadder
}
}
return padder
}
func encodeMeta(reader hashReader, cd CipherData) (Envelope, error) {
iv := base64.StdEncoding.EncodeToString(cd.IV)
key := base64.StdEncoding.EncodeToString(cd.EncryptedKey)
md5 := reader.GetValue()
contentLength := reader.GetContentLength()
md5Str := base64.StdEncoding.EncodeToString(md5)
matdesc, err := cd.MaterialDescription.encodeDescription()
if err != nil {
return Envelope{}, err
}
return Envelope{
CipherKey: key,
IV: iv,
MatDesc: string(matdesc),
WrapAlg: cd.WrapAlgorithm,
CEKAlg: cd.CEKAlgorithm,
TagLen: cd.TagLength,
UnencryptedMD5: md5Str,
UnencryptedContentLen: strconv.FormatInt(contentLength, 10),
}, nil
}

View File

@@ -0,0 +1,267 @@
package s3crypto
import (
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kms"
)
func TestWrapFactory(t *testing.T) {
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(unit.Session),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
}
env := Envelope{
WrapAlg: KMSWrap,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
w, ok := wrap.(*kmsKeyHandler)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if wrap == nil {
t.Error("expected non-nil value")
}
if !ok {
t.Errorf("expected kmsKeyHandler, but received %v", *w)
}
}
func TestWrapFactoryErrorNoWrap(t *testing.T) {
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(unit.Session),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
}
env := Envelope{
WrapAlg: "none",
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
if err == nil {
t.Error("expected error, but received none")
}
if wrap != nil {
t.Errorf("expected nil wrap value, received %v", wrap)
}
}
func TestWrapFactoryCustomEntry(t *testing.T) {
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
"custom": (kmsKeyHandler{
kms: kms.New(unit.Session),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
}
env := Envelope{
WrapAlg: "custom",
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if wrap == nil {
t.Errorf("expected nil wrap value, received %v", wrap)
}
}
func TestCEKFactory(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(sess),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
PadderRegistry: map[string]Padder{
NoPadder.Name(): NoPadder,
},
}
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
ivB64 := base64.URLEncoding.EncodeToString(iv)
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
env := Envelope{
WrapAlg: KMSWrap,
CEKAlg: AESGCMNoPadding,
CipherKey: cipherKeyB64,
IV: ivB64,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
cek, err := c.cekFromEnvelope(env, wrap)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if cek == nil {
t.Errorf("expected non-nil cek")
}
}
func TestCEKFactoryNoCEK(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(sess),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
},
PadderRegistry: map[string]Padder{
NoPadder.Name(): NoPadder,
},
}
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
ivB64 := base64.URLEncoding.EncodeToString(iv)
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
env := Envelope{
WrapAlg: KMSWrap,
CEKAlg: "none",
CipherKey: cipherKeyB64,
IV: ivB64,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
cek, err := c.cekFromEnvelope(env, wrap)
if err == nil {
t.Error("expected error, but received none")
}
if cek != nil {
t.Errorf("expected nil cek value, received %v", wrap)
}
}
func TestCEKFactoryCustomEntry(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := DecryptionClient{
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(sess),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
"custom": newAESGCMContentCipher,
},
PadderRegistry: map[string]Padder{},
}
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
ivB64 := base64.URLEncoding.EncodeToString(iv)
cipherKey, err := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
cipherKeyB64 := base64.URLEncoding.EncodeToString(cipherKey)
env := Envelope{
WrapAlg: KMSWrap,
CEKAlg: "custom",
CipherKey: cipherKeyB64,
IV: ivB64,
MatDesc: `{"kms_cmk_id":""}`,
}
wrap, err := c.wrapFromEnvelope(env)
cek, err := c.cekFromEnvelope(env, wrap)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if cek == nil {
t.Errorf("expected non-nil cek")
}
}

View File

@@ -0,0 +1,135 @@
package s3crypto
import (
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// WrapEntry is builder that return a proper key decrypter and error
type WrapEntry func(Envelope) (CipherDataDecrypter, error)
// CEKEntry is a builder thatn returns a proper content decrypter and error
type CEKEntry func(CipherData) (ContentCipher, error)
// DecryptionClient is an S3 crypto client. The decryption client
// will handle all get object requests from Amazon S3.
// Supported key wrapping algorithms:
// *AWS KMS
//
// Supported content ciphers:
// * AES/GCM
// * AES/CBC
type DecryptionClient struct {
S3Client s3iface.S3API
// LoadStrategy is used to load the metadata either from the metadata of the object
// or from a separate file in s3.
//
// Defaults to our default load strategy.
LoadStrategy LoadStrategy
WrapRegistry map[string]WrapEntry
CEKRegistry map[string]CEKEntry
PadderRegistry map[string]Padder
}
// NewDecryptionClient instantiates a new S3 crypto client
//
// Example:
// sess := session.New()
// svc := s3crypto.NewDecryptionClient(sess, func(svc *s3crypto.DecryptionClient{
// // Custom client options here
// }))
func NewDecryptionClient(prov client.ConfigProvider, options ...func(*DecryptionClient)) *DecryptionClient {
s3client := s3.New(prov)
client := &DecryptionClient{
S3Client: s3client,
LoadStrategy: defaultV2LoadStrategy{
client: s3client,
},
WrapRegistry: map[string]WrapEntry{
KMSWrap: (kmsKeyHandler{
kms: kms.New(prov),
}).decryptHandler,
},
CEKRegistry: map[string]CEKEntry{
AESGCMNoPadding: newAESGCMContentCipher,
strings.Join([]string{AESCBC, AESCBCPadder.Name()}, "/"): newAESCBCContentCipher,
},
PadderRegistry: map[string]Padder{
strings.Join([]string{AESCBC, AESCBCPadder.Name()}, "/"): AESCBCPadder,
"NoPadding": NoPadder,
},
}
for _, option := range options {
option(client)
}
return client
}
// GetObjectRequest will make a request to s3 and retrieve the object. In this process
// decryption will be done. The SDK only supports V2 reads of KMS and GCM.
//
// Example:
// sess := session.New()
// svc := s3crypto.NewDecryptionClient(sess)
// req, out := svc.GetObjectRequest(&s3.GetObjectInput {
// Key: aws.String("testKey"),
// Bucket: aws.String("testBucket"),
// })
// err := req.Send()
func (c *DecryptionClient) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
req, out := c.S3Client.GetObjectRequest(input)
req.Handlers.Unmarshal.PushBack(func(r *request.Request) {
env, err := c.LoadStrategy.Load(r)
if err != nil {
r.Error = err
out.Body.Close()
return
}
// If KMS should return the correct CEK algorithm with the proper
// KMS key provider
cipher, err := c.contentCipherFromEnvelope(env)
if err != nil {
r.Error = err
out.Body.Close()
return
}
reader, err := cipher.DecryptContents(out.Body)
if err != nil {
r.Error = err
out.Body.Close()
return
}
out.Body = reader
})
return req, out
}
// GetObject is a wrapper for GetObjectRequest
func (c *DecryptionClient) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) {
req, out := c.GetObjectRequest(input)
return out, req.Send()
}
// GetObjectWithContext is a wrapper for GetObjectRequest with the additional
// context, and request options support.
//
// GetObjectWithContext is the same as GetObject with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlining, timeouts, etc. In the future
// this may create sub-contexts for individual underlying requests.
func (c *DecryptionClient) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
req, out := c.GetObjectRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return out, req.Send()
}

View File

@@ -0,0 +1,250 @@
package s3crypto_test
import (
"bytes"
"encoding/base64"
"encoding/hex"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestGetObjectGCM(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.StdEncoding.EncodeToString(key)
// This is our KMS response
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewDecryptionClient(sess)
if c == nil {
t.Error("expected non-nil value")
}
input := &s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
req, out := c.GetObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
iv, err := hex.DecodeString("0d18e06c7c725ac9e362e1ce")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
b, err := hex.DecodeString("fa4362189661d163fcd6a56d8bf0405ad636ac1bbedd5cc3ee727dc2ab4a9489")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
http.CanonicalHeaderKey("x-amz-meta-x-amz-key-v2"): []string{"SpFRES0JyU8BLZSKo51SrwILK4lhtZsWiMNjgO4WmoK+joMwZPG7Hw=="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-iv"): []string{base64.URLEncoding.EncodeToString(iv)},
http.CanonicalHeaderKey("x-amz-meta-x-amz-matdesc"): []string{`{"kms_cmk_id":"arn:aws:kms:us-east-1:172259396726:key/a22a4b30-79f4-4b3d-bab4-a26d327a231b"}`},
http.CanonicalHeaderKey("x-amz-meta-x-amz-wrap-alg"): []string{s3crypto.KMSWrap},
http.CanonicalHeaderKey("x-amz-meta-x-amz-cek-alg"): []string{s3crypto.AESGCMNoPadding},
http.CanonicalHeaderKey("x-amz-meta-x-amz-tag-len"): []string{"128"},
},
Body: ioutil.NopCloser(bytes.NewBuffer(b)),
}
out.Metadata = make(map[string]*string)
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
})
err := req.Send()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
b, err := ioutil.ReadAll(out.Body)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
expected, err := hex.DecodeString("2db5168e932556f8089a0622981d017d")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !bytes.Equal(expected, b) {
t.Error("expected bytes to be equivalent")
}
}
func TestGetObjectCBC(t *testing.T) {
key, _ := hex.DecodeString("898be9cc5004ed0fa6e117c9a3099d31")
keyB64 := base64.StdEncoding.EncodeToString(key)
// This is our KMS response
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewDecryptionClient(sess)
if c == nil {
t.Error("expected non-nil value")
}
input := &s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
req, out := c.GetObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
iv, err := hex.DecodeString("9dea7621945988f96491083849b068df")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
b, err := hex.DecodeString("e232cd6ef50047801ee681ec30f61d53cfd6b0bca02fd03c1b234baa10ea82ac9dab8b960926433a19ce6dea08677e34")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
http.CanonicalHeaderKey("x-amz-meta-x-amz-key-v2"): []string{"SpFRES0JyU8BLZSKo51SrwILK4lhtZsWiMNjgO4WmoK+joMwZPG7Hw=="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-iv"): []string{base64.URLEncoding.EncodeToString(iv)},
http.CanonicalHeaderKey("x-amz-meta-x-amz-matdesc"): []string{`{"kms_cmk_id":"arn:aws:kms:us-east-1:172259396726:key/a22a4b30-79f4-4b3d-bab4-a26d327a231b"}`},
http.CanonicalHeaderKey("x-amz-meta-x-amz-wrap-alg"): []string{s3crypto.KMSWrap},
http.CanonicalHeaderKey("x-amz-meta-x-amz-cek-alg"): []string{strings.Join([]string{s3crypto.AESCBC, s3crypto.AESCBCPadder.Name()}, "/")},
},
Body: ioutil.NopCloser(bytes.NewBuffer(b)),
}
out.Metadata = make(map[string]*string)
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
})
err := req.Send()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
b, err := ioutil.ReadAll(out.Body)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
expected, err := hex.DecodeString("0397f4f6820b1f9386f14403be5ac16e50213bd473b4874b9bcbf5f318ee686b1d")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !bytes.Equal(expected, b) {
t.Error("expected bytes to be equivalent")
}
}
func TestGetObjectCBC2(t *testing.T) {
key, _ := hex.DecodeString("8d70e92489c4e6cfb12261b4d17f4b85826da687fc8742fcf9f87fadb5b4cb89")
keyB64 := base64.StdEncoding.EncodeToString(key)
// This is our KMS response
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
defer ts.Close()
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewDecryptionClient(sess)
if c == nil {
t.Error("expected non-nil value")
}
input := &s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
req, out := c.GetObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
b, err := hex.DecodeString("fd0c71ecb7ed16a9bf42ea5f75501d416df608f190890c3b4d8897f24744cd7f9ea4a0b212e60634302450e1c5378f047ff753ccefe365d411c36339bf22e301fae4c3a6226719a4b93dc74c1af79d0296659b5d56c0892315f2c7cc30190220db1eaafae3920d6d9c65d0aa366499afc17af493454e141c6e0fbdeb6a990cb4")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
http.CanonicalHeaderKey("x-amz-meta-x-amz-key-v2"): []string{"AQEDAHikdGvcj7Gil5VqAR/JWvvPp3ue26+t2vhWy4lL2hg4mAAAAH4wfAYJKoZIhvcNAQcGoG8wbQIBADBoBgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDCcy43wCR0bSsnzTrAIBEIA7WdD2jxC3tCrK6TOdiEfbIN64m+UN7Velz4y0LRra5jn2U1CDClacwIpiBYuDp5ymPKO+ZqUGE0WEf20="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-iv"): []string{"EMMWJY8ZLcK/9FOj3iCpng=="},
http.CanonicalHeaderKey("x-amz-meta-x-amz-matdesc"): []string{`{"kms_cmk_id":"arn:aws:kms:us-east-1:172259396726:key/a22a4b30-79f4-4b3d-bab4-a26d327a231b"}`},
http.CanonicalHeaderKey("x-amz-meta-x-amz-wrap-alg"): []string{s3crypto.KMSWrap},
http.CanonicalHeaderKey("x-amz-meta-x-amz-cek-alg"): []string{strings.Join([]string{s3crypto.AESCBC, s3crypto.AESCBCPadder.Name()}, "/")},
},
Body: ioutil.NopCloser(bytes.NewBuffer(b)),
}
out.Metadata = make(map[string]*string)
out.Metadata["x-amz-wrap-alg"] = aws.String(s3crypto.KMSWrap)
})
err := req.Send()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
b, err := ioutil.ReadAll(out.Body)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
expected, err := hex.DecodeString("a6ccd3482f5ce25c9ddeb69437cd0acbc0bdda2ef8696d90781de2b35704543529871b2032e68ef1c5baed1769aba8d420d1aca181341b49b8b3587a6580cdf1d809c68f06735f7735c16691f4b70c967d68fc08195b81ad71bcc4df452fd0a5799c1e1234f92f1cd929fc072167ccf9f2ac85b93170932b32")
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !bytes.Equal(expected, b) {
t.Error("expected bytes to be equivalent")
}
}
func TestGetObjectWithContext(t *testing.T) {
c := s3crypto.NewDecryptionClient(unit.Session)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
input := s3.GetObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
}
_, err := c.GetObjectWithContext(ctx, &input)
if err == nil {
t.Fatalf("expected error, did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expected error code %q, got %q", e, a)
}
if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}

View File

@@ -0,0 +1,66 @@
/*
Package s3crypto provides encryption to S3 using KMS and AES GCM.
Keyproviders are interfaces that handle masterkeys. Masterkeys are used to encrypt and decrypt the randomly
generated cipher keys. The SDK currently uses KMS to do this. A user does not need to provide a master key
since all that information is hidden in KMS.
Modes are interfaces that handle content encryption and decryption. It is an abstraction layer that instantiates
the ciphers. If content is being encrypted we generate the key and iv of the cipher. For decryption, we use the
metadata stored either on the object or an instruction file object to decrypt the contents.
Ciphers are interfaces that handle encryption and decryption of data. This may be key wrap ciphers or content
ciphers.
Creating an S3 cryptography client
cmkID := "<some key ID>"
sess := session.New()
// Create the KeyProvider
handler := s3crypto.NewKMSKeyGenerator(kms.New(sess), cmkID)
// Create an encryption and decryption client
// We need to pass the session here so S3 can use it. In addition, any decryption that
// occurs will use the KMS client.
svc := s3crypto.NewEncryptionClient(sess, s3crypto.AESGCMContentCipherBuilder(handler))
svc := s3crypto.NewDecryptionClient(sess)
Configuration of the S3 cryptography client
cfg := s3crypto.EncryptionConfig{
// Save instruction files to separate objects
SaveStrategy: NewS3SaveStrategy(session.New(), ""),
// Change instruction file suffix to .example
InstructionFileSuffix: ".example",
// Set temp folder path
TempFolderPath: "/path/to/tmp/folder/",
// Any content less than the minimum file size will use memory
// instead of writing the contents to a temp file.
MinFileSize: int64(1024 * 1024 * 1024),
}
The default SaveStrategy is to the object's header.
The InstructionFileSuffix defaults to .instruction. Careful here though, if you do this, be sure you know
what that suffix is in grabbing data. All requests will look for fooKey.example instead of fooKey.instruction.
This suffix only affects gets and not puts. Put uses the keyprovider's suffix.
Registration of new wrap or cek algorithms are also supported by the SDK. Let's say we want to support `AES Wrap`
and `AES CTR`. Let's assume we have already defined the functionality.
svc := s3crypto.NewDecryptionClient(sess)
svc.WrapRegistry["AESWrap"] = NewAESWrap
svc.CEKRegistry["AES/CTR/NoPadding"] = NewAESCTR
We have now registered these new algorithms to the decryption client. When the client calls `GetObject` and sees
the wrap as `AESWrap` then it'll use that wrap algorithm. This is also true for `AES/CTR/NoPadding`.
For encryption adding a custom content cipher builder and key handler will allow for encryption of custom
defined ciphers.
// Our wrap algorithm, AESWrap
handler := NewAESWrap(key, iv)
// Our content cipher builder, AESCTRContentCipherBuilder
svc := s3crypto.NewEncryptionClient(sess, NewAESCTRContentCipherBuilder(handler))
*/
package s3crypto

View File

@@ -0,0 +1,146 @@
package s3crypto
import (
"encoding/hex"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkio"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// DefaultMinFileSize is used to check whether we want to write to a temp file
// or store the data in memory.
const DefaultMinFileSize = 1024 * 512 * 5
// EncryptionClient is an S3 crypto client. By default the SDK will use Authentication mode which
// will use KMS for key wrapping and AES GCM for content encryption.
// AES GCM will load all data into memory. However, the rest of the content algorithms
// do not load the entire contents into memory.
type EncryptionClient struct {
S3Client s3iface.S3API
ContentCipherBuilder ContentCipherBuilder
// SaveStrategy will dictate where the envelope is saved.
//
// Defaults to the object's metadata
SaveStrategy SaveStrategy
// TempFolderPath is used to store temp files when calling PutObject.
// Temporary files are needed to compute the X-Amz-Content-Sha256 header.
TempFolderPath string
// MinFileSize is the minimum size for the content to write to a
// temporary file instead of using memory.
MinFileSize int64
}
// NewEncryptionClient instantiates a new S3 crypto client
//
// Example:
// cmkID := "arn:aws:kms:region:000000000000:key/00000000-0000-0000-0000-000000000000"
// sess := session.New()
// handler := s3crypto.NewKMSKeyGenerator(kms.New(sess), cmkID)
// svc := s3crypto.New(sess, s3crypto.AESGCMContentCipherBuilder(handler))
func NewEncryptionClient(prov client.ConfigProvider, builder ContentCipherBuilder, options ...func(*EncryptionClient)) *EncryptionClient {
client := &EncryptionClient{
S3Client: s3.New(prov),
ContentCipherBuilder: builder,
SaveStrategy: HeaderV2SaveStrategy{},
MinFileSize: DefaultMinFileSize,
}
for _, option := range options {
option(client)
}
return client
}
// PutObjectRequest creates a temp file to encrypt the contents into. It then streams
// that data to S3.
//
// Example:
// svc := s3crypto.New(session.New(), s3crypto.AESGCMContentCipherBuilder(handler))
// req, out := svc.PutObjectRequest(&s3.PutObjectInput {
// Key: aws.String("testKey"),
// Bucket: aws.String("testBucket"),
// Body: strings.NewReader("test data"),
// })
// err := req.Send()
func (c *EncryptionClient) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
req, out := c.S3Client.PutObjectRequest(input)
// Get Size of file
n, err := aws.SeekerLen(input.Body)
if err != nil {
req.Error = err
return req, out
}
dst, err := getWriterStore(req, c.TempFolderPath, n >= c.MinFileSize)
if err != nil {
req.Error = err
return req, out
}
encryptor, err := c.ContentCipherBuilder.ContentCipher()
req.Handlers.Build.PushFront(func(r *request.Request) {
if err != nil {
r.Error = err
return
}
md5 := newMD5Reader(input.Body)
sha := newSHA256Writer(dst)
reader, err := encryptor.EncryptContents(md5)
if err != nil {
r.Error = err
return
}
_, err = io.Copy(sha, reader)
if err != nil {
r.Error = err
return
}
data := encryptor.GetCipherData()
env, err := encodeMeta(md5, data)
if err != nil {
r.Error = err
return
}
shaHex := hex.EncodeToString(sha.GetValue())
req.HTTPRequest.Header.Set("X-Amz-Content-Sha256", shaHex)
dst.Seek(0, sdkio.SeekStart)
input.Body = dst
err = c.SaveStrategy.Save(env, r)
r.Error = err
})
return req, out
}
// PutObject is a wrapper for PutObjectRequest
func (c *EncryptionClient) PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
req, out := c.PutObjectRequest(input)
return out, req.Send()
}
// PutObjectWithContext is a wrapper for PutObjectRequest with the additional
// context, and request options support.
//
// PutObjectWithContext is the same as PutObject with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlining, timeouts, etc. In the future
// this may create sub-contexts for individual underlying requests.
func (c *EncryptionClient) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) {
req, out := c.PutObjectRequest(input)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return out, req.Send()
}

View File

@@ -0,0 +1,111 @@
package s3crypto_test
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestDefaultConfigValues(t *testing.T) {
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
svc := kms.New(sess)
handler := s3crypto.NewKMSKeyGenerator(svc, "testid")
c := s3crypto.NewEncryptionClient(sess, s3crypto.AESGCMContentCipherBuilder(handler))
if c == nil {
t.Error("expected non-vil client value")
}
if c.ContentCipherBuilder == nil {
t.Error("expected non-vil content cipher builder value")
}
if c.SaveStrategy == nil {
t.Error("expected non-vil save strategy value")
}
}
func TestPutObject(t *testing.T) {
size := 1024 * 1024
data := make([]byte, size)
expected := bytes.Repeat([]byte{1}, size)
generator := mockGenerator{}
cb := mockCipherBuilder{generator}
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
c := s3crypto.NewEncryptionClient(sess, cb)
if c == nil {
t.Error("expected non-vil client value")
}
input := &s3.PutObjectInput{
Key: aws.String("test"),
Bucket: aws.String("test"),
Body: bytes.NewReader(data),
}
req, _ := c.PutObjectRequest(input)
req.Handlers.Send.Clear()
req.Handlers.Send.PushBack(func(r *request.Request) {
r.Error = errors.New("stop")
r.HTTPResponse = &http.Response{
StatusCode: 200,
}
})
err := req.Send()
if e, a := "stop", err.Error(); e != a {
t.Errorf("expected %s error, but received %s", e, a)
}
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !bytes.Equal(expected, b) {
t.Error("expected bytes to be equivalent, but received otherwise")
}
}
func TestPutObjectWithContext(t *testing.T) {
generator := mockGenerator{}
cb := mockCipherBuilder{generator}
c := s3crypto.NewEncryptionClient(unit.Session, cb)
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
input := s3.PutObjectInput{
Bucket: aws.String("test"),
Key: aws.String("test"),
Body: bytes.NewReader([]byte{}),
}
_, err := c.PutObjectWithContext(ctx, &input)
if err == nil {
t.Fatalf("expected error, did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expected error code %q, got %q", e, a)
}
if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}

View File

@@ -0,0 +1,37 @@
package s3crypto
// DefaultInstructionKeySuffix is appended to the end of the instruction file key when
// grabbing or saving to S3
const DefaultInstructionKeySuffix = ".instruction"
const (
metaHeader = "x-amz-meta"
keyV1Header = "x-amz-key"
keyV2Header = keyV1Header + "-v2"
ivHeader = "x-amz-iv"
matDescHeader = "x-amz-matdesc"
cekAlgorithmHeader = "x-amz-cek-alg"
wrapAlgorithmHeader = "x-amz-wrap-alg"
tagLengthHeader = "x-amz-tag-len"
unencryptedMD5Header = "x-amz-unencrypted-content-md5"
unencryptedContentLengthHeader = "x-amz-unencrypted-content-length"
)
// Envelope encryption starts off by generating a random symmetric key using
// AES GCM. The SDK generates a random IV based off the encryption cipher
// chosen. The master key that was provided, whether by the user or KMS, will be used
// to encrypt the randomly generated symmetric key and base64 encode the iv. This will
// allow for decryption of that same data later.
type Envelope struct {
// IV is the randomly generated IV base64 encoded.
IV string `json:"x-amz-iv"`
// CipherKey is the randomly generated cipher key.
CipherKey string `json:"x-amz-key-v2"`
// MaterialDesc is a description to distinguish from other envelopes.
MatDesc string `json:"x-amz-matdesc"`
WrapAlg string `json:"x-amz-wrap-alg"`
CEKAlg string `json:"x-amz-cek-alg"`
TagLen string `json:"x-amz-tag-len"`
UnencryptedMD5 string `json:"x-amz-unencrypted-content-md5"`
UnencryptedContentLen string `json:"x-amz-unencrypted-content-length"`
}

View File

@@ -0,0 +1,61 @@
package s3crypto
import (
"crypto/md5"
"crypto/sha256"
"hash"
"io"
)
// hashReader is used for calculating SHA256 when following the sigv4 specification.
// Additionally this used for calculating the unencrypted MD5.
type hashReader interface {
GetValue() []byte
GetContentLength() int64
}
type sha256Writer struct {
sha256 []byte
hash hash.Hash
out io.Writer
}
func newSHA256Writer(f io.Writer) *sha256Writer {
return &sha256Writer{hash: sha256.New(), out: f}
}
func (r *sha256Writer) Write(b []byte) (int, error) {
r.hash.Write(b)
return r.out.Write(b)
}
func (r *sha256Writer) GetValue() []byte {
return r.hash.Sum(nil)
}
type md5Reader struct {
contentLength int64
hash hash.Hash
body io.Reader
}
func newMD5Reader(body io.Reader) *md5Reader {
return &md5Reader{hash: md5.New(), body: body}
}
func (w *md5Reader) Read(b []byte) (int, error) {
n, err := w.body.Read(b)
if err != nil && err != io.EOF {
return n, err
}
w.contentLength += int64(n)
w.hash.Write(b[:n])
return n, err
}
func (w *md5Reader) GetValue() []byte {
return w.hash.Sum(nil)
}
func (w *md5Reader) GetContentLength() int64 {
return w.contentLength
}

View File

@@ -0,0 +1,29 @@
package s3crypto
import (
"bytes"
"encoding/hex"
"testing"
)
// From Go stdlib encoding/sha256 test cases
func TestSHA256(t *testing.T) {
sha := newSHA256Writer(nil)
expected, _ := hex.DecodeString("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
b := sha.GetValue()
if !bytes.Equal(expected, b) {
t.Errorf("expected equivalent sha values, but received otherwise")
}
}
func TestSHA256_Case2(t *testing.T) {
sha := newSHA256Writer(bytes.NewBuffer([]byte{}))
sha.Write([]byte("hello"))
expected, _ := hex.DecodeString("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824")
b := sha.GetValue()
if !bytes.Equal(expected, b) {
t.Errorf("expected equivalent sha values, but received otherwise")
}
}

View File

@@ -0,0 +1,70 @@
package s3crypto
import (
"errors"
"io"
"io/ioutil"
"os"
"path/filepath"
"github.com/aws/aws-sdk-go/aws/request"
)
func getWriterStore(req *request.Request, path string, useTempFile bool) (io.ReadWriteSeeker, error) {
if !useTempFile {
return &bytesReadWriteSeeker{}, nil
}
// Create temp file to be used later for calculating the SHA256 header
f, err := ioutil.TempFile(path, "")
if err != nil {
return nil, err
}
req.Handlers.Send.PushBack(func(r *request.Request) {
// Close the temp file and cleanup
f.Close()
fpath := filepath.Join(path, f.Name())
os.Remove(fpath)
})
return f, nil
}
type bytesReadWriteSeeker struct {
buf []byte
i int64
}
// Copied from Go stdlib bytes.Reader
func (ws *bytesReadWriteSeeker) Read(b []byte) (int, error) {
if ws.i >= int64(len(ws.buf)) {
return 0, io.EOF
}
n := copy(b, ws.buf[ws.i:])
ws.i += int64(n)
return n, nil
}
func (ws *bytesReadWriteSeeker) Write(b []byte) (int, error) {
ws.buf = append(ws.buf, b...)
return len(b), nil
}
// Copied from Go stdlib bytes.Reader
func (ws *bytesReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
var abs int64
switch whence {
case 0:
abs = offset
case 1:
abs = int64(ws.i) + offset
case 2:
abs = int64(len(ws.buf)) + offset
default:
return 0, errors.New("bytes.Reader.Seek: invalid whence")
}
if abs < 0 {
return 0, errors.New("bytes.Reader.Seek: negative position")
}
ws.i = abs
return abs, nil
}

View File

@@ -0,0 +1,84 @@
package s3crypto
import (
"bytes"
"testing"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
func TestBytesReadWriteSeeker_Read(t *testing.T) {
b := &bytesReadWriteSeeker{[]byte{1, 2, 3}, 0}
expected := []byte{1, 2, 3}
buf := make([]byte, 3)
n, err := b.Read(buf)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 3, n; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
if !bytes.Equal(expected, buf) {
t.Error("expected equivalent byte slices, but received otherwise")
}
}
func TestBytesReadWriteSeeker_Write(t *testing.T) {
b := &bytesReadWriteSeeker{}
expected := []byte{1, 2, 3}
buf := make([]byte, 3)
n, err := b.Write([]byte{1, 2, 3})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 3, n; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
n, err = b.Read(buf)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 3, n; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
if !bytes.Equal(expected, buf) {
t.Error("expected equivalent byte slices, but received otherwise")
}
}
func TestBytesReadWriteSeeker_Seek(t *testing.T) {
b := &bytesReadWriteSeeker{[]byte{1, 2, 3}, 0}
expected := []byte{2, 3}
m, err := b.Seek(1, sdkio.SeekStart)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 1, int(m); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
buf := make([]byte, 3)
n, err := b.Read(buf)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := 2, n; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
if !bytes.Equal(expected, buf[:n]) {
t.Error("expected equivalent byte slices, but received otherwise")
}
}

View File

@@ -0,0 +1,21 @@
package s3crypto
import "crypto/rand"
// CipherDataGenerator handles generating proper key and IVs of proper size for the
// content cipher. CipherDataGenerator will also encrypt the key and store it in
// the CipherData.
type CipherDataGenerator interface {
GenerateCipherData(int, int) (CipherData, error)
}
// CipherDataDecrypter is a handler to decrypt keys from the envelope.
type CipherDataDecrypter interface {
DecryptKey([]byte) ([]byte, error)
}
func generateBytes(n int) []byte {
b := make([]byte, n)
rand.Read(b)
return b
}

View File

@@ -0,0 +1,20 @@
package s3crypto
import (
"testing"
)
func TestGenerateBytes(t *testing.T) {
b := generateBytes(5)
if e, a := 5, len(b); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
b = generateBytes(0)
if e, a := 0, len(b); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
b = generateBytes(1024)
if e, a := 1024, len(b); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}

View File

@@ -0,0 +1,133 @@
package s3crypto
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
)
const (
// KMSWrap is a constant used during decryption to build a KMS key handler.
KMSWrap = "kms"
)
// kmsKeyHandler will make calls to KMS to get the masterkey
type kmsKeyHandler struct {
kms kmsiface.KMSAPI
cmkID *string
CipherData
}
// NewKMSKeyGenerator builds a new KMS key provider using the customer key ID and material
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler := s3crypto.NewKMSKeyGenerator(kms.New(sess), cmkID)
func NewKMSKeyGenerator(kmsClient kmsiface.KMSAPI, cmkID string) CipherDataGenerator {
return NewKMSKeyGeneratorWithMatDesc(kmsClient, cmkID, MaterialDescription{})
}
// NewKMSKeyGeneratorWithMatDesc builds a new KMS key provider using the customer key ID and material
// description.
//
// Example:
// sess := session.New(&aws.Config{})
// cmkID := "arn to key"
// matdesc := s3crypto.MaterialDescription{}
// handler := s3crypto.NewKMSKeyGeneratorWithMatDesc(kms.New(sess), cmkID, matdesc)
func NewKMSKeyGeneratorWithMatDesc(kmsClient kmsiface.KMSAPI, cmkID string, matdesc MaterialDescription) CipherDataGenerator {
if matdesc == nil {
matdesc = MaterialDescription{}
}
matdesc["kms_cmk_id"] = &cmkID
// These values are read only making them thread safe
kp := &kmsKeyHandler{
kms: kmsClient,
cmkID: &cmkID,
}
// These values are read only making them thread safe
kp.CipherData.WrapAlgorithm = KMSWrap
kp.CipherData.MaterialDescription = matdesc
return kp
}
// NewKMSWrapEntry builds returns a new KMS key provider and its decrypt handler.
//
// Example:
// sess := session.New(&aws.Config{})
// customKMSClient := kms.New(sess)
// decryptHandler := s3crypto.NewKMSWrapEntry(customKMSClient)
//
// svc := s3crypto.NewDecryptionClient(sess, func(svc *s3crypto.DecryptionClient{
// svc.WrapRegistry[KMSWrap] = decryptHandler
// }))
func NewKMSWrapEntry(kmsClient kmsiface.KMSAPI) WrapEntry {
// These values are read only making them thread safe
kp := &kmsKeyHandler{
kms: kmsClient,
}
return kp.decryptHandler
}
// decryptHandler initializes a KMS keyprovider with a material description. This
// is used with Decrypting kms content, due to the cmkID being in the material description.
func (kp kmsKeyHandler) decryptHandler(env Envelope) (CipherDataDecrypter, error) {
m := MaterialDescription{}
err := m.decodeDescription([]byte(env.MatDesc))
if err != nil {
return nil, err
}
cmkID, ok := m["kms_cmk_id"]
if !ok {
return nil, awserr.New("MissingCMKIDError", "Material description is missing CMK ID", nil)
}
kp.CipherData.MaterialDescription = m
kp.cmkID = cmkID
kp.WrapAlgorithm = KMSWrap
return &kp, nil
}
// DecryptKey makes a call to KMS to decrypt the key.
func (kp *kmsKeyHandler) DecryptKey(key []byte) ([]byte, error) {
out, err := kp.kms.Decrypt(&kms.DecryptInput{
EncryptionContext: map[string]*string(kp.CipherData.MaterialDescription),
CiphertextBlob: key,
GrantTokens: []*string{},
})
if err != nil {
return nil, err
}
return out.Plaintext, nil
}
// GenerateCipherData makes a call to KMS to generate a data key, Upon making
// the call, it also sets the encrypted key.
func (kp *kmsKeyHandler) GenerateCipherData(keySize, ivSize int) (CipherData, error) {
out, err := kp.kms.GenerateDataKey(&kms.GenerateDataKeyInput{
EncryptionContext: kp.CipherData.MaterialDescription,
KeyId: kp.cmkID,
KeySpec: aws.String("AES_256"),
})
if err != nil {
return CipherData{}, err
}
iv := generateBytes(ivSize)
cd := CipherData{
Key: out.Plaintext,
IV: iv,
WrapAlgorithm: KMSWrap,
MaterialDescription: kp.CipherData.MaterialDescription,
EncryptedKey: out.CiphertextBlob,
}
return cd, nil
}

View File

@@ -0,0 +1,125 @@
package s3crypto
import (
"bytes"
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/kms"
)
func TestBuildKMSEncryptHandler(t *testing.T) {
svc := kms.New(unit.Session)
handler := NewKMSKeyGenerator(svc, "testid")
if handler == nil {
t.Error("expected non-nil handler")
}
}
func TestBuildKMSEncryptHandlerWithMatDesc(t *testing.T) {
svc := kms.New(unit.Session)
handler := NewKMSKeyGeneratorWithMatDesc(svc, "testid", MaterialDescription{
"Testing": aws.String("123"),
})
if handler == nil {
t.Error("expected non-nil handler")
}
kmsHandler := handler.(*kmsKeyHandler)
expected := MaterialDescription{
"kms_cmk_id": aws.String("testid"),
"Testing": aws.String("123"),
}
if !reflect.DeepEqual(expected, kmsHandler.CipherData.MaterialDescription) {
t.Errorf("expected %v, but received %v", expected, kmsHandler.CipherData.MaterialDescription)
}
}
func TestKMSGenerateCipherData(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `{"CiphertextBlob":"AQEDAHhqBCCY1MSimw8gOGcUma79cn4ANvTtQyv9iuBdbcEF1QAAAH4wfAYJKoZIhvcNAQcGoG8wbQIBADBoBgkqhkiG9w0BBwEwHgYJYIZIAWUDBAEuMBEEDJ6IcN5E4wVbk38MNAIBEIA7oF1E3lS7FY9DkoxPc/UmJsEwHzL82zMqoLwXIvi8LQHr8If4Lv6zKqY8u0+JRgSVoqCvZDx3p8Cn6nM=","KeyId":"arn:aws:kms:us-west-2:042062605278:key/c80a5cdb-8d09-4f9f-89ee-df01b2e3870a","Plaintext":"6tmyz9JLBE2yIuU7iXpArqpDVle172WSmxjcO6GNT7E="}`)
}))
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
svc := kms.New(sess)
handler := NewKMSKeyGenerator(svc, "testid")
keySize := 32
ivSize := 16
cd, err := handler.GenerateCipherData(keySize, ivSize)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if keySize != len(cd.Key) {
t.Errorf("expected %d, but received %d", keySize, len(cd.Key))
}
if ivSize != len(cd.IV) {
t.Errorf("expected %d, but received %d", ivSize, len(cd.IV))
}
}
func TestKMSDecrypt(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
handler, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{MatDesc: `{"kms_cmk_id":"test"}`})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
plaintextKey, err := handler.DecryptKey([]byte{1, 2, 3, 4})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !bytes.Equal(key, plaintextKey) {
t.Errorf("expected %v, but received %v", key, plaintextKey)
}
}
func TestKMSDecryptBadJSON(t *testing.T) {
key, _ := hex.DecodeString("31bdadd96698c204aa9ce1448ea94ae1fb4a9a0b3c9d773b51bb1822666b8f22")
keyB64 := base64.URLEncoding.EncodeToString(key)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, fmt.Sprintf("%s%s%s", `{"KeyId":"test-key-id","Plaintext":"`, keyB64, `"}`))
}))
sess := unit.Session.Copy(&aws.Config{
MaxRetries: aws.Int(0),
Endpoint: aws.String(ts.URL),
DisableSSL: aws.Bool(true),
S3ForcePathStyle: aws.Bool(true),
Region: aws.String("us-west-2"),
})
_, err := (kmsKeyHandler{kms: kms.New(sess)}).decryptHandler(Envelope{MatDesc: `{"kms_cmk_id":"test"`})
if err == nil {
t.Errorf("expected error, but received none")
}
}

View File

@@ -0,0 +1,18 @@
package s3crypto
import (
"encoding/json"
)
// MaterialDescription is used to identify how and what master
// key has been used.
type MaterialDescription map[string]*string
func (md *MaterialDescription) encodeDescription() ([]byte, error) {
v, err := json.Marshal(&md)
return v, err
}
func (md *MaterialDescription) decodeDescription(b []byte) error {
return json.Unmarshal(b, &md)
}

View File

@@ -0,0 +1,35 @@
package s3crypto
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
)
func TestEncodeMaterialDescription(t *testing.T) {
md := MaterialDescription{}
md["foo"] = aws.String("bar")
b, err := md.encodeDescription()
expected := `{"foo":"bar"}`
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if expected != string(b) {
t.Errorf("expected %s, but received %s", expected, string(b))
}
}
func TestDecodeMaterialDescription(t *testing.T) {
md := MaterialDescription{}
json := `{"foo":"bar"}`
err := md.decodeDescription([]byte(json))
expected := MaterialDescription{
"foo": aws.String("bar"),
}
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !reflect.DeepEqual(expected, md) {
t.Error("expected material description to be equivalent, but received otherwise")
}
}

View File

@@ -0,0 +1,70 @@
package s3crypto_test
import (
"bytes"
"io"
"io/ioutil"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
type mockGenerator struct {
}
func (m mockGenerator) GenerateCipherData(keySize, ivSize int) (s3crypto.CipherData, error) {
cd := s3crypto.CipherData{
Key: make([]byte, keySize),
IV: make([]byte, ivSize),
}
return cd, nil
}
func (m mockGenerator) EncryptKey(key []byte) ([]byte, error) {
size := len(key)
b := bytes.Repeat([]byte{1}, size)
return b, nil
}
func (m mockGenerator) DecryptKey(key []byte) ([]byte, error) {
return make([]byte, 16), nil
}
type mockCipherBuilder struct {
generator s3crypto.CipherDataGenerator
}
func (builder mockCipherBuilder) ContentCipher() (s3crypto.ContentCipher, error) {
cd, err := builder.generator.GenerateCipherData(32, 16)
if err != nil {
return nil, err
}
return &mockContentCipher{cd}, nil
}
type mockContentCipher struct {
cd s3crypto.CipherData
}
func (cipher *mockContentCipher) GetCipherData() s3crypto.CipherData {
return cipher.cd
}
func (cipher *mockContentCipher) EncryptContents(src io.Reader) (io.Reader, error) {
b, err := ioutil.ReadAll(src)
if err != nil {
return nil, err
}
size := len(b)
b = bytes.Repeat([]byte{1}, size)
return bytes.NewReader(b), nil
}
func (cipher *mockContentCipher) DecryptContents(src io.ReadCloser) (io.ReadCloser, error) {
b, err := ioutil.ReadAll(src)
if err != nil {
return nil, err
}
size := len(b)
return ioutil.NopCloser(bytes.NewReader(make([]byte, size))), nil
}

View File

@@ -0,0 +1,35 @@
package s3crypto
// Padder handles padding of crypto data
type Padder interface {
// Pad will pad the byte array.
// The second parameter is NOT how many
// bytes to pad by, but how many bytes
// have been read prior to the padding.
// This allows for streamable padding.
Pad([]byte, int) ([]byte, error)
// Unpad will unpad the byte bytes. Unpad
// methods must be constant time.
Unpad([]byte) ([]byte, error)
// Name returns the name of the padder.
// This is used when decrypting on
// instantiating new padders.
Name() string
}
// NoPadder does not pad anything
var NoPadder = Padder(noPadder{})
type noPadder struct{}
func (padder noPadder) Pad(b []byte, n int) ([]byte, error) {
return b, nil
}
func (padder noPadder) Unpad(b []byte) ([]byte, error) {
return b, nil
}
func (padder noPadder) Name() string {
return "NoPadding"
}

View File

@@ -0,0 +1,80 @@
package s3crypto
// Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Portions Licensed under the MIT License. Copyright (c) 2016 Carl Jackson
import (
"bytes"
"crypto/subtle"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const (
pkcs7MaxPaddingSize = 255
)
type pkcs7Padder struct {
blockSize int
}
// NewPKCS7Padder follows the RFC 2315: https://www.ietf.org/rfc/rfc2315.txt
// PKCS7 padding is subject to side-channel attacks and timing attacks. For
// the most secure data, use an authenticated crypto algorithm.
func NewPKCS7Padder(blockSize int) Padder {
return pkcs7Padder{blockSize}
}
var errPKCS7Padding = awserr.New("InvalidPadding", "invalid padding", nil)
// Pad will pad the data relative to how many bytes have been read.
// Pad follows the PKCS7 standard.
func (padder pkcs7Padder) Pad(buf []byte, n int) ([]byte, error) {
if padder.blockSize < 1 || padder.blockSize > pkcs7MaxPaddingSize {
return nil, awserr.New("InvalidBlockSize", "block size must be between 1 and 255", nil)
}
size := padder.blockSize - (n % padder.blockSize)
pad := bytes.Repeat([]byte{byte(size)}, size)
buf = append(buf, pad...)
return buf, nil
}
// Unpad will unpad the correct amount of bytes based off
// of the PKCS7 standard
func (padder pkcs7Padder) Unpad(buf []byte) ([]byte, error) {
if len(buf) == 0 {
return nil, errPKCS7Padding
}
// Here be dragons. We're attempting to check the padding in constant
// time. The only piece of information here which is public is len(buf).
// This code is modeled loosely after tls1_cbc_remove_padding from
// OpenSSL.
padLen := buf[len(buf)-1]
toCheck := pkcs7MaxPaddingSize
good := 1
if toCheck > len(buf) {
toCheck = len(buf)
}
for i := 0; i < toCheck; i++ {
b := buf[len(buf)-1-i]
outOfRange := subtle.ConstantTimeLessOrEq(int(padLen), i)
equal := subtle.ConstantTimeByteEq(padLen, b)
good &= subtle.ConstantTimeSelect(outOfRange, 1, equal)
}
good &= subtle.ConstantTimeLessOrEq(1, int(padLen))
good &= subtle.ConstantTimeLessOrEq(int(padLen), len(buf))
if good != 1 {
return nil, errPKCS7Padding
}
return buf[:len(buf)-int(padLen)], nil
}
func (padder pkcs7Padder) Name() string {
return "PKCS7Padding"
}

View File

@@ -0,0 +1,57 @@
package s3crypto_test
import (
"bytes"
"fmt"
"testing"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func padTest(size int, t *testing.T) {
padder := s3crypto.NewPKCS7Padder(size)
for i := 0; i < size; i++ {
input := make([]byte, i)
expected := append(input, bytes.Repeat([]byte{byte(size - i)}, size-i)...)
b, err := padder.Pad(input, len(input))
if err != nil {
t.Fatal("Expected error to be nil but received " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}
func unpadTest(size int, t *testing.T) {
padder := s3crypto.NewPKCS7Padder(size)
for i := 0; i < size; i++ {
expected := make([]byte, i)
input := append(expected, bytes.Repeat([]byte{byte(size - i)}, size-i)...)
b, err := padder.Unpad(input)
if err != nil {
t.Fatal("Error received, was expecting nil: " + err.Error())
}
if len(b) != len(expected) {
t.Fatal(fmt.Sprintf("Case %d: data is not of the same length", i))
}
if bytes.Compare(b, expected) != 0 {
t.Fatal(fmt.Sprintf("Expected %v but got %v", expected, b))
}
}
}
func TestPKCS7Padding(t *testing.T) {
padTest(10, t)
padTest(16, t)
padTest(255, t)
}
func TestPKCS7Unpadding(t *testing.T) {
unpadTest(10, t)
unpadTest(16, t)
unpadTest(255, t)
}

View File

@@ -0,0 +1,145 @@
package s3crypto
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
)
// SaveStrategy is how the data's metadata wants to be saved
type SaveStrategy interface {
Save(Envelope, *request.Request) error
}
// S3SaveStrategy will save the metadata to a separate instruction file in S3
type S3SaveStrategy struct {
Client *s3.S3
InstructionFileSuffix string
}
// Save will save the envelope contents to s3.
func (strat S3SaveStrategy) Save(env Envelope, req *request.Request) error {
input := req.Params.(*s3.PutObjectInput)
b, err := json.Marshal(env)
if err != nil {
return err
}
instInput := s3.PutObjectInput{
Bucket: input.Bucket,
Body: bytes.NewReader(b),
}
if strat.InstructionFileSuffix == "" {
instInput.Key = aws.String(*input.Key + DefaultInstructionKeySuffix)
} else {
instInput.Key = aws.String(*input.Key + strat.InstructionFileSuffix)
}
_, err = strat.Client.PutObject(&instInput)
return err
}
// HeaderV2SaveStrategy will save the metadata of the crypto contents to the header of
// the object.
type HeaderV2SaveStrategy struct{}
// Save will save the envelope to the request's header.
func (strat HeaderV2SaveStrategy) Save(env Envelope, req *request.Request) error {
input := req.Params.(*s3.PutObjectInput)
if input.Metadata == nil {
input.Metadata = map[string]*string{}
}
input.Metadata[http.CanonicalHeaderKey(keyV2Header)] = &env.CipherKey
input.Metadata[http.CanonicalHeaderKey(ivHeader)] = &env.IV
input.Metadata[http.CanonicalHeaderKey(matDescHeader)] = &env.MatDesc
input.Metadata[http.CanonicalHeaderKey(wrapAlgorithmHeader)] = &env.WrapAlg
input.Metadata[http.CanonicalHeaderKey(cekAlgorithmHeader)] = &env.CEKAlg
input.Metadata[http.CanonicalHeaderKey(unencryptedMD5Header)] = &env.UnencryptedMD5
input.Metadata[http.CanonicalHeaderKey(unencryptedContentLengthHeader)] = &env.UnencryptedContentLen
if len(env.TagLen) > 0 {
input.Metadata[http.CanonicalHeaderKey(tagLengthHeader)] = &env.TagLen
}
return nil
}
// LoadStrategy ...
type LoadStrategy interface {
Load(*request.Request) (Envelope, error)
}
// S3LoadStrategy will load the instruction file from s3
type S3LoadStrategy struct {
Client *s3.S3
InstructionFileSuffix string
}
// Load from a given instruction file suffix
func (load S3LoadStrategy) Load(req *request.Request) (Envelope, error) {
env := Envelope{}
if load.InstructionFileSuffix == "" {
load.InstructionFileSuffix = DefaultInstructionKeySuffix
}
input := req.Params.(*s3.GetObjectInput)
out, err := load.Client.GetObject(&s3.GetObjectInput{
Key: aws.String(strings.Join([]string{*input.Key, load.InstructionFileSuffix}, "")),
Bucket: input.Bucket,
})
if err != nil {
return env, err
}
b, err := ioutil.ReadAll(out.Body)
if err != nil {
return env, err
}
err = json.Unmarshal(b, &env)
return env, err
}
// HeaderV2LoadStrategy will load the envelope from the metadata
type HeaderV2LoadStrategy struct{}
// Load from a given object's header
func (load HeaderV2LoadStrategy) Load(req *request.Request) (Envelope, error) {
env := Envelope{}
env.CipherKey = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV2Header}, "-"))
env.IV = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, ivHeader}, "-"))
env.MatDesc = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, matDescHeader}, "-"))
env.WrapAlg = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, wrapAlgorithmHeader}, "-"))
env.CEKAlg = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, cekAlgorithmHeader}, "-"))
env.TagLen = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, tagLengthHeader}, "-"))
env.UnencryptedMD5 = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, unencryptedMD5Header}, "-"))
env.UnencryptedContentLen = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, unencryptedContentLengthHeader}, "-"))
return env, nil
}
type defaultV2LoadStrategy struct {
client *s3.S3
suffix string
}
func (load defaultV2LoadStrategy) Load(req *request.Request) (Envelope, error) {
if value := req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV2Header}, "-")); value != "" {
strat := HeaderV2LoadStrategy{}
return strat.Load(req)
} else if value = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV1Header}, "-")); value != "" {
return Envelope{}, awserr.New("V1NotSupportedError", "The AWS SDK for Go does not support version 1", nil)
}
strat := S3LoadStrategy{
Client: load.client,
InstructionFileSuffix: load.suffix,
}
return strat.Load(req)
}

View File

@@ -0,0 +1,77 @@
package s3crypto_test
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
)
func TestHeaderV2SaveStrategy(t *testing.T) {
cases := []struct {
env s3crypto.Envelope
expected map[string]*string
}{
{
s3crypto.Envelope{
CipherKey: "Foo",
IV: "Bar",
MatDesc: "{}",
WrapAlg: s3crypto.KMSWrap,
CEKAlg: s3crypto.AESGCMNoPadding,
TagLen: "128",
UnencryptedMD5: "hello",
UnencryptedContentLen: "0",
},
map[string]*string{
"X-Amz-Key-V2": aws.String("Foo"),
"X-Amz-Iv": aws.String("Bar"),
"X-Amz-Matdesc": aws.String("{}"),
"X-Amz-Wrap-Alg": aws.String(s3crypto.KMSWrap),
"X-Amz-Cek-Alg": aws.String(s3crypto.AESGCMNoPadding),
"X-Amz-Tag-Len": aws.String("128"),
"X-Amz-Unencrypted-Content-Md5": aws.String("hello"),
"X-Amz-Unencrypted-Content-Length": aws.String("0"),
},
},
{
s3crypto.Envelope{
CipherKey: "Foo",
IV: "Bar",
MatDesc: "{}",
WrapAlg: s3crypto.KMSWrap,
CEKAlg: s3crypto.AESGCMNoPadding,
UnencryptedMD5: "hello",
UnencryptedContentLen: "0",
},
map[string]*string{
"X-Amz-Key-V2": aws.String("Foo"),
"X-Amz-Iv": aws.String("Bar"),
"X-Amz-Matdesc": aws.String("{}"),
"X-Amz-Wrap-Alg": aws.String(s3crypto.KMSWrap),
"X-Amz-Cek-Alg": aws.String(s3crypto.AESGCMNoPadding),
"X-Amz-Unencrypted-Content-Md5": aws.String("hello"),
"X-Amz-Unencrypted-Content-Length": aws.String("0"),
},
},
}
for _, c := range cases {
params := &s3.PutObjectInput{}
req := &request.Request{
Params: params,
}
strat := s3crypto.HeaderV2SaveStrategy{}
err := strat.Save(c.env, req)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if !reflect.DeepEqual(c.expected, params.Metadata) {
t.Errorf("expected %v, but received %v", c.expected, params.Metadata)
}
}
}

View File

@@ -0,0 +1,403 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
// Package s3iface provides an interface to enable mocking the Amazon Simple Storage Service service client
// for testing your code.
//
// It is important to note that this interface will have breaking changes
// when the service model is updated and adds new API operations, paginators,
// and waiters.
package s3iface
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
)
// S3API provides an interface to enable mocking the
// s3.S3 service client's API operation,
// paginators, and waiters. This make unit testing your code that calls out
// to the SDK's service client's calls easier.
//
// The best way to use this interface is so the SDK's service client's calls
// can be stubbed out for unit testing your code with the SDK without needing
// to inject custom request handlers into the SDK's request pipeline.
//
// // myFunc uses an SDK service client to make a request to
// // Amazon Simple Storage Service.
// func myFunc(svc s3iface.S3API) bool {
// // Make svc.AbortMultipartUpload request
// }
//
// func main() {
// sess := session.New()
// svc := s3.New(sess)
//
// myFunc(svc)
// }
//
// In your _test.go file:
//
// // Define a mock struct to be used in your unit tests of myFunc.
// type mockS3Client struct {
// s3iface.S3API
// }
// func (m *mockS3Client) AbortMultipartUpload(input *s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error) {
// // mock response/functionality
// }
//
// func TestMyFunc(t *testing.T) {
// // Setup Test
// mockSvc := &mockS3Client{}
//
// myfunc(mockSvc)
//
// // Verify myFunc's functionality
// }
//
// It is important to note that this interface will have breaking changes
// when the service model is updated and adds new API operations, paginators,
// and waiters. Its suggested to use the pattern above for testing, or using
// tooling to generate mocks to satisfy the interfaces.
type S3API interface {
AbortMultipartUpload(*s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
AbortMultipartUploadWithContext(aws.Context, *s3.AbortMultipartUploadInput, ...request.Option) (*s3.AbortMultipartUploadOutput, error)
AbortMultipartUploadRequest(*s3.AbortMultipartUploadInput) (*request.Request, *s3.AbortMultipartUploadOutput)
CompleteMultipartUpload(*s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
CompleteMultipartUploadWithContext(aws.Context, *s3.CompleteMultipartUploadInput, ...request.Option) (*s3.CompleteMultipartUploadOutput, error)
CompleteMultipartUploadRequest(*s3.CompleteMultipartUploadInput) (*request.Request, *s3.CompleteMultipartUploadOutput)
CopyObject(*s3.CopyObjectInput) (*s3.CopyObjectOutput, error)
CopyObjectWithContext(aws.Context, *s3.CopyObjectInput, ...request.Option) (*s3.CopyObjectOutput, error)
CopyObjectRequest(*s3.CopyObjectInput) (*request.Request, *s3.CopyObjectOutput)
CreateBucket(*s3.CreateBucketInput) (*s3.CreateBucketOutput, error)
CreateBucketWithContext(aws.Context, *s3.CreateBucketInput, ...request.Option) (*s3.CreateBucketOutput, error)
CreateBucketRequest(*s3.CreateBucketInput) (*request.Request, *s3.CreateBucketOutput)
CreateMultipartUpload(*s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
CreateMultipartUploadWithContext(aws.Context, *s3.CreateMultipartUploadInput, ...request.Option) (*s3.CreateMultipartUploadOutput, error)
CreateMultipartUploadRequest(*s3.CreateMultipartUploadInput) (*request.Request, *s3.CreateMultipartUploadOutput)
DeleteBucket(*s3.DeleteBucketInput) (*s3.DeleteBucketOutput, error)
DeleteBucketWithContext(aws.Context, *s3.DeleteBucketInput, ...request.Option) (*s3.DeleteBucketOutput, error)
DeleteBucketRequest(*s3.DeleteBucketInput) (*request.Request, *s3.DeleteBucketOutput)
DeleteBucketAnalyticsConfiguration(*s3.DeleteBucketAnalyticsConfigurationInput) (*s3.DeleteBucketAnalyticsConfigurationOutput, error)
DeleteBucketAnalyticsConfigurationWithContext(aws.Context, *s3.DeleteBucketAnalyticsConfigurationInput, ...request.Option) (*s3.DeleteBucketAnalyticsConfigurationOutput, error)
DeleteBucketAnalyticsConfigurationRequest(*s3.DeleteBucketAnalyticsConfigurationInput) (*request.Request, *s3.DeleteBucketAnalyticsConfigurationOutput)
DeleteBucketCors(*s3.DeleteBucketCorsInput) (*s3.DeleteBucketCorsOutput, error)
DeleteBucketCorsWithContext(aws.Context, *s3.DeleteBucketCorsInput, ...request.Option) (*s3.DeleteBucketCorsOutput, error)
DeleteBucketCorsRequest(*s3.DeleteBucketCorsInput) (*request.Request, *s3.DeleteBucketCorsOutput)
DeleteBucketEncryption(*s3.DeleteBucketEncryptionInput) (*s3.DeleteBucketEncryptionOutput, error)
DeleteBucketEncryptionWithContext(aws.Context, *s3.DeleteBucketEncryptionInput, ...request.Option) (*s3.DeleteBucketEncryptionOutput, error)
DeleteBucketEncryptionRequest(*s3.DeleteBucketEncryptionInput) (*request.Request, *s3.DeleteBucketEncryptionOutput)
DeleteBucketInventoryConfiguration(*s3.DeleteBucketInventoryConfigurationInput) (*s3.DeleteBucketInventoryConfigurationOutput, error)
DeleteBucketInventoryConfigurationWithContext(aws.Context, *s3.DeleteBucketInventoryConfigurationInput, ...request.Option) (*s3.DeleteBucketInventoryConfigurationOutput, error)
DeleteBucketInventoryConfigurationRequest(*s3.DeleteBucketInventoryConfigurationInput) (*request.Request, *s3.DeleteBucketInventoryConfigurationOutput)
DeleteBucketLifecycle(*s3.DeleteBucketLifecycleInput) (*s3.DeleteBucketLifecycleOutput, error)
DeleteBucketLifecycleWithContext(aws.Context, *s3.DeleteBucketLifecycleInput, ...request.Option) (*s3.DeleteBucketLifecycleOutput, error)
DeleteBucketLifecycleRequest(*s3.DeleteBucketLifecycleInput) (*request.Request, *s3.DeleteBucketLifecycleOutput)
DeleteBucketMetricsConfiguration(*s3.DeleteBucketMetricsConfigurationInput) (*s3.DeleteBucketMetricsConfigurationOutput, error)
DeleteBucketMetricsConfigurationWithContext(aws.Context, *s3.DeleteBucketMetricsConfigurationInput, ...request.Option) (*s3.DeleteBucketMetricsConfigurationOutput, error)
DeleteBucketMetricsConfigurationRequest(*s3.DeleteBucketMetricsConfigurationInput) (*request.Request, *s3.DeleteBucketMetricsConfigurationOutput)
DeleteBucketPolicy(*s3.DeleteBucketPolicyInput) (*s3.DeleteBucketPolicyOutput, error)
DeleteBucketPolicyWithContext(aws.Context, *s3.DeleteBucketPolicyInput, ...request.Option) (*s3.DeleteBucketPolicyOutput, error)
DeleteBucketPolicyRequest(*s3.DeleteBucketPolicyInput) (*request.Request, *s3.DeleteBucketPolicyOutput)
DeleteBucketReplication(*s3.DeleteBucketReplicationInput) (*s3.DeleteBucketReplicationOutput, error)
DeleteBucketReplicationWithContext(aws.Context, *s3.DeleteBucketReplicationInput, ...request.Option) (*s3.DeleteBucketReplicationOutput, error)
DeleteBucketReplicationRequest(*s3.DeleteBucketReplicationInput) (*request.Request, *s3.DeleteBucketReplicationOutput)
DeleteBucketTagging(*s3.DeleteBucketTaggingInput) (*s3.DeleteBucketTaggingOutput, error)
DeleteBucketTaggingWithContext(aws.Context, *s3.DeleteBucketTaggingInput, ...request.Option) (*s3.DeleteBucketTaggingOutput, error)
DeleteBucketTaggingRequest(*s3.DeleteBucketTaggingInput) (*request.Request, *s3.DeleteBucketTaggingOutput)
DeleteBucketWebsite(*s3.DeleteBucketWebsiteInput) (*s3.DeleteBucketWebsiteOutput, error)
DeleteBucketWebsiteWithContext(aws.Context, *s3.DeleteBucketWebsiteInput, ...request.Option) (*s3.DeleteBucketWebsiteOutput, error)
DeleteBucketWebsiteRequest(*s3.DeleteBucketWebsiteInput) (*request.Request, *s3.DeleteBucketWebsiteOutput)
DeleteObject(*s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error)
DeleteObjectWithContext(aws.Context, *s3.DeleteObjectInput, ...request.Option) (*s3.DeleteObjectOutput, error)
DeleteObjectRequest(*s3.DeleteObjectInput) (*request.Request, *s3.DeleteObjectOutput)
DeleteObjectTagging(*s3.DeleteObjectTaggingInput) (*s3.DeleteObjectTaggingOutput, error)
DeleteObjectTaggingWithContext(aws.Context, *s3.DeleteObjectTaggingInput, ...request.Option) (*s3.DeleteObjectTaggingOutput, error)
DeleteObjectTaggingRequest(*s3.DeleteObjectTaggingInput) (*request.Request, *s3.DeleteObjectTaggingOutput)
DeleteObjects(*s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error)
DeleteObjectsWithContext(aws.Context, *s3.DeleteObjectsInput, ...request.Option) (*s3.DeleteObjectsOutput, error)
DeleteObjectsRequest(*s3.DeleteObjectsInput) (*request.Request, *s3.DeleteObjectsOutput)
GetBucketAccelerateConfiguration(*s3.GetBucketAccelerateConfigurationInput) (*s3.GetBucketAccelerateConfigurationOutput, error)
GetBucketAccelerateConfigurationWithContext(aws.Context, *s3.GetBucketAccelerateConfigurationInput, ...request.Option) (*s3.GetBucketAccelerateConfigurationOutput, error)
GetBucketAccelerateConfigurationRequest(*s3.GetBucketAccelerateConfigurationInput) (*request.Request, *s3.GetBucketAccelerateConfigurationOutput)
GetBucketAcl(*s3.GetBucketAclInput) (*s3.GetBucketAclOutput, error)
GetBucketAclWithContext(aws.Context, *s3.GetBucketAclInput, ...request.Option) (*s3.GetBucketAclOutput, error)
GetBucketAclRequest(*s3.GetBucketAclInput) (*request.Request, *s3.GetBucketAclOutput)
GetBucketAnalyticsConfiguration(*s3.GetBucketAnalyticsConfigurationInput) (*s3.GetBucketAnalyticsConfigurationOutput, error)
GetBucketAnalyticsConfigurationWithContext(aws.Context, *s3.GetBucketAnalyticsConfigurationInput, ...request.Option) (*s3.GetBucketAnalyticsConfigurationOutput, error)
GetBucketAnalyticsConfigurationRequest(*s3.GetBucketAnalyticsConfigurationInput) (*request.Request, *s3.GetBucketAnalyticsConfigurationOutput)
GetBucketCors(*s3.GetBucketCorsInput) (*s3.GetBucketCorsOutput, error)
GetBucketCorsWithContext(aws.Context, *s3.GetBucketCorsInput, ...request.Option) (*s3.GetBucketCorsOutput, error)
GetBucketCorsRequest(*s3.GetBucketCorsInput) (*request.Request, *s3.GetBucketCorsOutput)
GetBucketEncryption(*s3.GetBucketEncryptionInput) (*s3.GetBucketEncryptionOutput, error)
GetBucketEncryptionWithContext(aws.Context, *s3.GetBucketEncryptionInput, ...request.Option) (*s3.GetBucketEncryptionOutput, error)
GetBucketEncryptionRequest(*s3.GetBucketEncryptionInput) (*request.Request, *s3.GetBucketEncryptionOutput)
GetBucketInventoryConfiguration(*s3.GetBucketInventoryConfigurationInput) (*s3.GetBucketInventoryConfigurationOutput, error)
GetBucketInventoryConfigurationWithContext(aws.Context, *s3.GetBucketInventoryConfigurationInput, ...request.Option) (*s3.GetBucketInventoryConfigurationOutput, error)
GetBucketInventoryConfigurationRequest(*s3.GetBucketInventoryConfigurationInput) (*request.Request, *s3.GetBucketInventoryConfigurationOutput)
GetBucketLifecycle(*s3.GetBucketLifecycleInput) (*s3.GetBucketLifecycleOutput, error)
GetBucketLifecycleWithContext(aws.Context, *s3.GetBucketLifecycleInput, ...request.Option) (*s3.GetBucketLifecycleOutput, error)
GetBucketLifecycleRequest(*s3.GetBucketLifecycleInput) (*request.Request, *s3.GetBucketLifecycleOutput)
GetBucketLifecycleConfiguration(*s3.GetBucketLifecycleConfigurationInput) (*s3.GetBucketLifecycleConfigurationOutput, error)
GetBucketLifecycleConfigurationWithContext(aws.Context, *s3.GetBucketLifecycleConfigurationInput, ...request.Option) (*s3.GetBucketLifecycleConfigurationOutput, error)
GetBucketLifecycleConfigurationRequest(*s3.GetBucketLifecycleConfigurationInput) (*request.Request, *s3.GetBucketLifecycleConfigurationOutput)
GetBucketLocation(*s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error)
GetBucketLocationWithContext(aws.Context, *s3.GetBucketLocationInput, ...request.Option) (*s3.GetBucketLocationOutput, error)
GetBucketLocationRequest(*s3.GetBucketLocationInput) (*request.Request, *s3.GetBucketLocationOutput)
GetBucketLogging(*s3.GetBucketLoggingInput) (*s3.GetBucketLoggingOutput, error)
GetBucketLoggingWithContext(aws.Context, *s3.GetBucketLoggingInput, ...request.Option) (*s3.GetBucketLoggingOutput, error)
GetBucketLoggingRequest(*s3.GetBucketLoggingInput) (*request.Request, *s3.GetBucketLoggingOutput)
GetBucketMetricsConfiguration(*s3.GetBucketMetricsConfigurationInput) (*s3.GetBucketMetricsConfigurationOutput, error)
GetBucketMetricsConfigurationWithContext(aws.Context, *s3.GetBucketMetricsConfigurationInput, ...request.Option) (*s3.GetBucketMetricsConfigurationOutput, error)
GetBucketMetricsConfigurationRequest(*s3.GetBucketMetricsConfigurationInput) (*request.Request, *s3.GetBucketMetricsConfigurationOutput)
GetBucketNotification(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfigurationDeprecated, error)
GetBucketNotificationWithContext(aws.Context, *s3.GetBucketNotificationConfigurationRequest, ...request.Option) (*s3.NotificationConfigurationDeprecated, error)
GetBucketNotificationRequest(*s3.GetBucketNotificationConfigurationRequest) (*request.Request, *s3.NotificationConfigurationDeprecated)
GetBucketNotificationConfiguration(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfiguration, error)
GetBucketNotificationConfigurationWithContext(aws.Context, *s3.GetBucketNotificationConfigurationRequest, ...request.Option) (*s3.NotificationConfiguration, error)
GetBucketNotificationConfigurationRequest(*s3.GetBucketNotificationConfigurationRequest) (*request.Request, *s3.NotificationConfiguration)
GetBucketPolicy(*s3.GetBucketPolicyInput) (*s3.GetBucketPolicyOutput, error)
GetBucketPolicyWithContext(aws.Context, *s3.GetBucketPolicyInput, ...request.Option) (*s3.GetBucketPolicyOutput, error)
GetBucketPolicyRequest(*s3.GetBucketPolicyInput) (*request.Request, *s3.GetBucketPolicyOutput)
GetBucketReplication(*s3.GetBucketReplicationInput) (*s3.GetBucketReplicationOutput, error)
GetBucketReplicationWithContext(aws.Context, *s3.GetBucketReplicationInput, ...request.Option) (*s3.GetBucketReplicationOutput, error)
GetBucketReplicationRequest(*s3.GetBucketReplicationInput) (*request.Request, *s3.GetBucketReplicationOutput)
GetBucketRequestPayment(*s3.GetBucketRequestPaymentInput) (*s3.GetBucketRequestPaymentOutput, error)
GetBucketRequestPaymentWithContext(aws.Context, *s3.GetBucketRequestPaymentInput, ...request.Option) (*s3.GetBucketRequestPaymentOutput, error)
GetBucketRequestPaymentRequest(*s3.GetBucketRequestPaymentInput) (*request.Request, *s3.GetBucketRequestPaymentOutput)
GetBucketTagging(*s3.GetBucketTaggingInput) (*s3.GetBucketTaggingOutput, error)
GetBucketTaggingWithContext(aws.Context, *s3.GetBucketTaggingInput, ...request.Option) (*s3.GetBucketTaggingOutput, error)
GetBucketTaggingRequest(*s3.GetBucketTaggingInput) (*request.Request, *s3.GetBucketTaggingOutput)
GetBucketVersioning(*s3.GetBucketVersioningInput) (*s3.GetBucketVersioningOutput, error)
GetBucketVersioningWithContext(aws.Context, *s3.GetBucketVersioningInput, ...request.Option) (*s3.GetBucketVersioningOutput, error)
GetBucketVersioningRequest(*s3.GetBucketVersioningInput) (*request.Request, *s3.GetBucketVersioningOutput)
GetBucketWebsite(*s3.GetBucketWebsiteInput) (*s3.GetBucketWebsiteOutput, error)
GetBucketWebsiteWithContext(aws.Context, *s3.GetBucketWebsiteInput, ...request.Option) (*s3.GetBucketWebsiteOutput, error)
GetBucketWebsiteRequest(*s3.GetBucketWebsiteInput) (*request.Request, *s3.GetBucketWebsiteOutput)
GetObject(*s3.GetObjectInput) (*s3.GetObjectOutput, error)
GetObjectWithContext(aws.Context, *s3.GetObjectInput, ...request.Option) (*s3.GetObjectOutput, error)
GetObjectRequest(*s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput)
GetObjectAcl(*s3.GetObjectAclInput) (*s3.GetObjectAclOutput, error)
GetObjectAclWithContext(aws.Context, *s3.GetObjectAclInput, ...request.Option) (*s3.GetObjectAclOutput, error)
GetObjectAclRequest(*s3.GetObjectAclInput) (*request.Request, *s3.GetObjectAclOutput)
GetObjectTagging(*s3.GetObjectTaggingInput) (*s3.GetObjectTaggingOutput, error)
GetObjectTaggingWithContext(aws.Context, *s3.GetObjectTaggingInput, ...request.Option) (*s3.GetObjectTaggingOutput, error)
GetObjectTaggingRequest(*s3.GetObjectTaggingInput) (*request.Request, *s3.GetObjectTaggingOutput)
GetObjectTorrent(*s3.GetObjectTorrentInput) (*s3.GetObjectTorrentOutput, error)
GetObjectTorrentWithContext(aws.Context, *s3.GetObjectTorrentInput, ...request.Option) (*s3.GetObjectTorrentOutput, error)
GetObjectTorrentRequest(*s3.GetObjectTorrentInput) (*request.Request, *s3.GetObjectTorrentOutput)
HeadBucket(*s3.HeadBucketInput) (*s3.HeadBucketOutput, error)
HeadBucketWithContext(aws.Context, *s3.HeadBucketInput, ...request.Option) (*s3.HeadBucketOutput, error)
HeadBucketRequest(*s3.HeadBucketInput) (*request.Request, *s3.HeadBucketOutput)
HeadObject(*s3.HeadObjectInput) (*s3.HeadObjectOutput, error)
HeadObjectWithContext(aws.Context, *s3.HeadObjectInput, ...request.Option) (*s3.HeadObjectOutput, error)
HeadObjectRequest(*s3.HeadObjectInput) (*request.Request, *s3.HeadObjectOutput)
ListBucketAnalyticsConfigurations(*s3.ListBucketAnalyticsConfigurationsInput) (*s3.ListBucketAnalyticsConfigurationsOutput, error)
ListBucketAnalyticsConfigurationsWithContext(aws.Context, *s3.ListBucketAnalyticsConfigurationsInput, ...request.Option) (*s3.ListBucketAnalyticsConfigurationsOutput, error)
ListBucketAnalyticsConfigurationsRequest(*s3.ListBucketAnalyticsConfigurationsInput) (*request.Request, *s3.ListBucketAnalyticsConfigurationsOutput)
ListBucketInventoryConfigurations(*s3.ListBucketInventoryConfigurationsInput) (*s3.ListBucketInventoryConfigurationsOutput, error)
ListBucketInventoryConfigurationsWithContext(aws.Context, *s3.ListBucketInventoryConfigurationsInput, ...request.Option) (*s3.ListBucketInventoryConfigurationsOutput, error)
ListBucketInventoryConfigurationsRequest(*s3.ListBucketInventoryConfigurationsInput) (*request.Request, *s3.ListBucketInventoryConfigurationsOutput)
ListBucketMetricsConfigurations(*s3.ListBucketMetricsConfigurationsInput) (*s3.ListBucketMetricsConfigurationsOutput, error)
ListBucketMetricsConfigurationsWithContext(aws.Context, *s3.ListBucketMetricsConfigurationsInput, ...request.Option) (*s3.ListBucketMetricsConfigurationsOutput, error)
ListBucketMetricsConfigurationsRequest(*s3.ListBucketMetricsConfigurationsInput) (*request.Request, *s3.ListBucketMetricsConfigurationsOutput)
ListBuckets(*s3.ListBucketsInput) (*s3.ListBucketsOutput, error)
ListBucketsWithContext(aws.Context, *s3.ListBucketsInput, ...request.Option) (*s3.ListBucketsOutput, error)
ListBucketsRequest(*s3.ListBucketsInput) (*request.Request, *s3.ListBucketsOutput)
ListMultipartUploads(*s3.ListMultipartUploadsInput) (*s3.ListMultipartUploadsOutput, error)
ListMultipartUploadsWithContext(aws.Context, *s3.ListMultipartUploadsInput, ...request.Option) (*s3.ListMultipartUploadsOutput, error)
ListMultipartUploadsRequest(*s3.ListMultipartUploadsInput) (*request.Request, *s3.ListMultipartUploadsOutput)
ListMultipartUploadsPages(*s3.ListMultipartUploadsInput, func(*s3.ListMultipartUploadsOutput, bool) bool) error
ListMultipartUploadsPagesWithContext(aws.Context, *s3.ListMultipartUploadsInput, func(*s3.ListMultipartUploadsOutput, bool) bool, ...request.Option) error
ListObjectVersions(*s3.ListObjectVersionsInput) (*s3.ListObjectVersionsOutput, error)
ListObjectVersionsWithContext(aws.Context, *s3.ListObjectVersionsInput, ...request.Option) (*s3.ListObjectVersionsOutput, error)
ListObjectVersionsRequest(*s3.ListObjectVersionsInput) (*request.Request, *s3.ListObjectVersionsOutput)
ListObjectVersionsPages(*s3.ListObjectVersionsInput, func(*s3.ListObjectVersionsOutput, bool) bool) error
ListObjectVersionsPagesWithContext(aws.Context, *s3.ListObjectVersionsInput, func(*s3.ListObjectVersionsOutput, bool) bool, ...request.Option) error
ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error)
ListObjectsWithContext(aws.Context, *s3.ListObjectsInput, ...request.Option) (*s3.ListObjectsOutput, error)
ListObjectsRequest(*s3.ListObjectsInput) (*request.Request, *s3.ListObjectsOutput)
ListObjectsPages(*s3.ListObjectsInput, func(*s3.ListObjectsOutput, bool) bool) error
ListObjectsPagesWithContext(aws.Context, *s3.ListObjectsInput, func(*s3.ListObjectsOutput, bool) bool, ...request.Option) error
ListObjectsV2(*s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error)
ListObjectsV2WithContext(aws.Context, *s3.ListObjectsV2Input, ...request.Option) (*s3.ListObjectsV2Output, error)
ListObjectsV2Request(*s3.ListObjectsV2Input) (*request.Request, *s3.ListObjectsV2Output)
ListObjectsV2Pages(*s3.ListObjectsV2Input, func(*s3.ListObjectsV2Output, bool) bool) error
ListObjectsV2PagesWithContext(aws.Context, *s3.ListObjectsV2Input, func(*s3.ListObjectsV2Output, bool) bool, ...request.Option) error
ListParts(*s3.ListPartsInput) (*s3.ListPartsOutput, error)
ListPartsWithContext(aws.Context, *s3.ListPartsInput, ...request.Option) (*s3.ListPartsOutput, error)
ListPartsRequest(*s3.ListPartsInput) (*request.Request, *s3.ListPartsOutput)
ListPartsPages(*s3.ListPartsInput, func(*s3.ListPartsOutput, bool) bool) error
ListPartsPagesWithContext(aws.Context, *s3.ListPartsInput, func(*s3.ListPartsOutput, bool) bool, ...request.Option) error
PutBucketAccelerateConfiguration(*s3.PutBucketAccelerateConfigurationInput) (*s3.PutBucketAccelerateConfigurationOutput, error)
PutBucketAccelerateConfigurationWithContext(aws.Context, *s3.PutBucketAccelerateConfigurationInput, ...request.Option) (*s3.PutBucketAccelerateConfigurationOutput, error)
PutBucketAccelerateConfigurationRequest(*s3.PutBucketAccelerateConfigurationInput) (*request.Request, *s3.PutBucketAccelerateConfigurationOutput)
PutBucketAcl(*s3.PutBucketAclInput) (*s3.PutBucketAclOutput, error)
PutBucketAclWithContext(aws.Context, *s3.PutBucketAclInput, ...request.Option) (*s3.PutBucketAclOutput, error)
PutBucketAclRequest(*s3.PutBucketAclInput) (*request.Request, *s3.PutBucketAclOutput)
PutBucketAnalyticsConfiguration(*s3.PutBucketAnalyticsConfigurationInput) (*s3.PutBucketAnalyticsConfigurationOutput, error)
PutBucketAnalyticsConfigurationWithContext(aws.Context, *s3.PutBucketAnalyticsConfigurationInput, ...request.Option) (*s3.PutBucketAnalyticsConfigurationOutput, error)
PutBucketAnalyticsConfigurationRequest(*s3.PutBucketAnalyticsConfigurationInput) (*request.Request, *s3.PutBucketAnalyticsConfigurationOutput)
PutBucketCors(*s3.PutBucketCorsInput) (*s3.PutBucketCorsOutput, error)
PutBucketCorsWithContext(aws.Context, *s3.PutBucketCorsInput, ...request.Option) (*s3.PutBucketCorsOutput, error)
PutBucketCorsRequest(*s3.PutBucketCorsInput) (*request.Request, *s3.PutBucketCorsOutput)
PutBucketEncryption(*s3.PutBucketEncryptionInput) (*s3.PutBucketEncryptionOutput, error)
PutBucketEncryptionWithContext(aws.Context, *s3.PutBucketEncryptionInput, ...request.Option) (*s3.PutBucketEncryptionOutput, error)
PutBucketEncryptionRequest(*s3.PutBucketEncryptionInput) (*request.Request, *s3.PutBucketEncryptionOutput)
PutBucketInventoryConfiguration(*s3.PutBucketInventoryConfigurationInput) (*s3.PutBucketInventoryConfigurationOutput, error)
PutBucketInventoryConfigurationWithContext(aws.Context, *s3.PutBucketInventoryConfigurationInput, ...request.Option) (*s3.PutBucketInventoryConfigurationOutput, error)
PutBucketInventoryConfigurationRequest(*s3.PutBucketInventoryConfigurationInput) (*request.Request, *s3.PutBucketInventoryConfigurationOutput)
PutBucketLifecycle(*s3.PutBucketLifecycleInput) (*s3.PutBucketLifecycleOutput, error)
PutBucketLifecycleWithContext(aws.Context, *s3.PutBucketLifecycleInput, ...request.Option) (*s3.PutBucketLifecycleOutput, error)
PutBucketLifecycleRequest(*s3.PutBucketLifecycleInput) (*request.Request, *s3.PutBucketLifecycleOutput)
PutBucketLifecycleConfiguration(*s3.PutBucketLifecycleConfigurationInput) (*s3.PutBucketLifecycleConfigurationOutput, error)
PutBucketLifecycleConfigurationWithContext(aws.Context, *s3.PutBucketLifecycleConfigurationInput, ...request.Option) (*s3.PutBucketLifecycleConfigurationOutput, error)
PutBucketLifecycleConfigurationRequest(*s3.PutBucketLifecycleConfigurationInput) (*request.Request, *s3.PutBucketLifecycleConfigurationOutput)
PutBucketLogging(*s3.PutBucketLoggingInput) (*s3.PutBucketLoggingOutput, error)
PutBucketLoggingWithContext(aws.Context, *s3.PutBucketLoggingInput, ...request.Option) (*s3.PutBucketLoggingOutput, error)
PutBucketLoggingRequest(*s3.PutBucketLoggingInput) (*request.Request, *s3.PutBucketLoggingOutput)
PutBucketMetricsConfiguration(*s3.PutBucketMetricsConfigurationInput) (*s3.PutBucketMetricsConfigurationOutput, error)
PutBucketMetricsConfigurationWithContext(aws.Context, *s3.PutBucketMetricsConfigurationInput, ...request.Option) (*s3.PutBucketMetricsConfigurationOutput, error)
PutBucketMetricsConfigurationRequest(*s3.PutBucketMetricsConfigurationInput) (*request.Request, *s3.PutBucketMetricsConfigurationOutput)
PutBucketNotification(*s3.PutBucketNotificationInput) (*s3.PutBucketNotificationOutput, error)
PutBucketNotificationWithContext(aws.Context, *s3.PutBucketNotificationInput, ...request.Option) (*s3.PutBucketNotificationOutput, error)
PutBucketNotificationRequest(*s3.PutBucketNotificationInput) (*request.Request, *s3.PutBucketNotificationOutput)
PutBucketNotificationConfiguration(*s3.PutBucketNotificationConfigurationInput) (*s3.PutBucketNotificationConfigurationOutput, error)
PutBucketNotificationConfigurationWithContext(aws.Context, *s3.PutBucketNotificationConfigurationInput, ...request.Option) (*s3.PutBucketNotificationConfigurationOutput, error)
PutBucketNotificationConfigurationRequest(*s3.PutBucketNotificationConfigurationInput) (*request.Request, *s3.PutBucketNotificationConfigurationOutput)
PutBucketPolicy(*s3.PutBucketPolicyInput) (*s3.PutBucketPolicyOutput, error)
PutBucketPolicyWithContext(aws.Context, *s3.PutBucketPolicyInput, ...request.Option) (*s3.PutBucketPolicyOutput, error)
PutBucketPolicyRequest(*s3.PutBucketPolicyInput) (*request.Request, *s3.PutBucketPolicyOutput)
PutBucketReplication(*s3.PutBucketReplicationInput) (*s3.PutBucketReplicationOutput, error)
PutBucketReplicationWithContext(aws.Context, *s3.PutBucketReplicationInput, ...request.Option) (*s3.PutBucketReplicationOutput, error)
PutBucketReplicationRequest(*s3.PutBucketReplicationInput) (*request.Request, *s3.PutBucketReplicationOutput)
PutBucketRequestPayment(*s3.PutBucketRequestPaymentInput) (*s3.PutBucketRequestPaymentOutput, error)
PutBucketRequestPaymentWithContext(aws.Context, *s3.PutBucketRequestPaymentInput, ...request.Option) (*s3.PutBucketRequestPaymentOutput, error)
PutBucketRequestPaymentRequest(*s3.PutBucketRequestPaymentInput) (*request.Request, *s3.PutBucketRequestPaymentOutput)
PutBucketTagging(*s3.PutBucketTaggingInput) (*s3.PutBucketTaggingOutput, error)
PutBucketTaggingWithContext(aws.Context, *s3.PutBucketTaggingInput, ...request.Option) (*s3.PutBucketTaggingOutput, error)
PutBucketTaggingRequest(*s3.PutBucketTaggingInput) (*request.Request, *s3.PutBucketTaggingOutput)
PutBucketVersioning(*s3.PutBucketVersioningInput) (*s3.PutBucketVersioningOutput, error)
PutBucketVersioningWithContext(aws.Context, *s3.PutBucketVersioningInput, ...request.Option) (*s3.PutBucketVersioningOutput, error)
PutBucketVersioningRequest(*s3.PutBucketVersioningInput) (*request.Request, *s3.PutBucketVersioningOutput)
PutBucketWebsite(*s3.PutBucketWebsiteInput) (*s3.PutBucketWebsiteOutput, error)
PutBucketWebsiteWithContext(aws.Context, *s3.PutBucketWebsiteInput, ...request.Option) (*s3.PutBucketWebsiteOutput, error)
PutBucketWebsiteRequest(*s3.PutBucketWebsiteInput) (*request.Request, *s3.PutBucketWebsiteOutput)
PutObject(*s3.PutObjectInput) (*s3.PutObjectOutput, error)
PutObjectWithContext(aws.Context, *s3.PutObjectInput, ...request.Option) (*s3.PutObjectOutput, error)
PutObjectRequest(*s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput)
PutObjectAcl(*s3.PutObjectAclInput) (*s3.PutObjectAclOutput, error)
PutObjectAclWithContext(aws.Context, *s3.PutObjectAclInput, ...request.Option) (*s3.PutObjectAclOutput, error)
PutObjectAclRequest(*s3.PutObjectAclInput) (*request.Request, *s3.PutObjectAclOutput)
PutObjectTagging(*s3.PutObjectTaggingInput) (*s3.PutObjectTaggingOutput, error)
PutObjectTaggingWithContext(aws.Context, *s3.PutObjectTaggingInput, ...request.Option) (*s3.PutObjectTaggingOutput, error)
PutObjectTaggingRequest(*s3.PutObjectTaggingInput) (*request.Request, *s3.PutObjectTaggingOutput)
RestoreObject(*s3.RestoreObjectInput) (*s3.RestoreObjectOutput, error)
RestoreObjectWithContext(aws.Context, *s3.RestoreObjectInput, ...request.Option) (*s3.RestoreObjectOutput, error)
RestoreObjectRequest(*s3.RestoreObjectInput) (*request.Request, *s3.RestoreObjectOutput)
SelectObjectContent(*s3.SelectObjectContentInput) (*s3.SelectObjectContentOutput, error)
SelectObjectContentWithContext(aws.Context, *s3.SelectObjectContentInput, ...request.Option) (*s3.SelectObjectContentOutput, error)
SelectObjectContentRequest(*s3.SelectObjectContentInput) (*request.Request, *s3.SelectObjectContentOutput)
UploadPart(*s3.UploadPartInput) (*s3.UploadPartOutput, error)
UploadPartWithContext(aws.Context, *s3.UploadPartInput, ...request.Option) (*s3.UploadPartOutput, error)
UploadPartRequest(*s3.UploadPartInput) (*request.Request, *s3.UploadPartOutput)
UploadPartCopy(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error)
UploadPartCopyWithContext(aws.Context, *s3.UploadPartCopyInput, ...request.Option) (*s3.UploadPartCopyOutput, error)
UploadPartCopyRequest(*s3.UploadPartCopyInput) (*request.Request, *s3.UploadPartCopyOutput)
WaitUntilBucketExists(*s3.HeadBucketInput) error
WaitUntilBucketExistsWithContext(aws.Context, *s3.HeadBucketInput, ...request.WaiterOption) error
WaitUntilBucketNotExists(*s3.HeadBucketInput) error
WaitUntilBucketNotExistsWithContext(aws.Context, *s3.HeadBucketInput, ...request.WaiterOption) error
WaitUntilObjectExists(*s3.HeadObjectInput) error
WaitUntilObjectExistsWithContext(aws.Context, *s3.HeadObjectInput, ...request.WaiterOption) error
WaitUntilObjectNotExists(*s3.HeadObjectInput) error
WaitUntilObjectNotExistsWithContext(aws.Context, *s3.HeadObjectInput, ...request.WaiterOption) error
}
var _ S3API = (*s3.S3)(nil)

View File

@@ -0,0 +1,529 @@
package s3manager
import (
"bytes"
"fmt"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
const (
// DefaultBatchSize is the batch size we initialize when constructing a batch delete client.
// This value is used when calling DeleteObjects. This represents how many objects to delete
// per DeleteObjects call.
DefaultBatchSize = 100
)
// BatchError will contain the key and bucket of the object that failed to
// either upload or download.
type BatchError struct {
Errors Errors
code string
message string
}
// Errors is a typed alias for a slice of errors to satisfy the error
// interface.
type Errors []Error
func (errs Errors) Error() string {
buf := bytes.NewBuffer(nil)
for i, err := range errs {
buf.WriteString(err.Error())
if i+1 < len(errs) {
buf.WriteString("\n")
}
}
return buf.String()
}
// Error will contain the original error, bucket, and key of the operation that failed
// during batch operations.
type Error struct {
OrigErr error
Bucket *string
Key *string
}
func newError(err error, bucket, key *string) Error {
return Error{
err,
bucket,
key,
}
}
func (err *Error) Error() string {
origErr := ""
if err.OrigErr != nil {
origErr = ":\n" + err.OrigErr.Error()
}
return fmt.Sprintf("failed to perform batch operation on %q to %q%s",
aws.StringValue(err.Key),
aws.StringValue(err.Bucket),
origErr,
)
}
// NewBatchError will return a BatchError that satisfies the awserr.Error interface.
func NewBatchError(code, message string, err []Error) awserr.Error {
return &BatchError{
Errors: err,
code: code,
message: message,
}
}
// Code will return the code associated with the batch error.
func (err *BatchError) Code() string {
return err.code
}
// Message will return the message associated with the batch error.
func (err *BatchError) Message() string {
return err.message
}
func (err *BatchError) Error() string {
return awserr.SprintError(err.Code(), err.Message(), "", err.Errors)
}
// OrigErr will return the original error. Which, in this case, will always be nil
// for batched operations.
func (err *BatchError) OrigErr() error {
return err.Errors
}
// BatchDeleteIterator is an interface that uses the scanner pattern to
// iterate through what needs to be deleted.
type BatchDeleteIterator interface {
Next() bool
Err() error
DeleteObject() BatchDeleteObject
}
// DeleteListIterator is an alternative iterator for the BatchDelete client. This will
// iterate through a list of objects and delete the objects.
//
// Example:
// iter := &s3manager.DeleteListIterator{
// Client: svc,
// Input: &s3.ListObjectsInput{
// Bucket: aws.String("bucket"),
// MaxKeys: aws.Int64(5),
// },
// Paginator: request.Pagination{
// NewRequest: func() (*request.Request, error) {
// var inCpy *ListObjectsInput
// if input != nil {
// tmp := *input
// inCpy = &tmp
// }
// req, _ := c.ListObjectsRequest(inCpy)
// return req, nil
// },
// },
// }
//
// batcher := s3manager.NewBatchDeleteWithClient(svc)
// if err := batcher.Delete(aws.BackgroundContext(), iter); err != nil {
// return err
// }
type DeleteListIterator struct {
Bucket *string
Paginator request.Pagination
objects []*s3.Object
}
// NewDeleteListIterator will return a new DeleteListIterator.
func NewDeleteListIterator(svc s3iface.S3API, input *s3.ListObjectsInput, opts ...func(*DeleteListIterator)) BatchDeleteIterator {
iter := &DeleteListIterator{
Bucket: input.Bucket,
Paginator: request.Pagination{
NewRequest: func() (*request.Request, error) {
var inCpy *s3.ListObjectsInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := svc.ListObjectsRequest(inCpy)
return req, nil
},
},
}
for _, opt := range opts {
opt(iter)
}
return iter
}
// Next will use the S3API client to iterate through a list of objects.
func (iter *DeleteListIterator) Next() bool {
if len(iter.objects) > 0 {
iter.objects = iter.objects[1:]
}
if len(iter.objects) == 0 && iter.Paginator.Next() {
iter.objects = iter.Paginator.Page().(*s3.ListObjectsOutput).Contents
}
return len(iter.objects) > 0
}
// Err will return the last known error from Next.
func (iter *DeleteListIterator) Err() error {
return iter.Paginator.Err()
}
// DeleteObject will return the current object to be deleted.
func (iter *DeleteListIterator) DeleteObject() BatchDeleteObject {
return BatchDeleteObject{
Object: &s3.DeleteObjectInput{
Bucket: iter.Bucket,
Key: iter.objects[0].Key,
},
}
}
// BatchDelete will use the s3 package's service client to perform a batch
// delete.
type BatchDelete struct {
Client s3iface.S3API
BatchSize int
}
// NewBatchDeleteWithClient will return a new delete client that can delete a batched amount of
// objects.
//
// Example:
// batcher := s3manager.NewBatchDeleteWithClient(client, size)
//
// objects := []BatchDeleteObject{
// {
// Object: &s3.DeleteObjectInput {
// Key: aws.String("key"),
// Bucket: aws.String("bucket"),
// },
// },
// }
//
// if err := batcher.Delete(aws.BackgroundContext(), &s3manager.DeleteObjectsIterator{
// Objects: objects,
// }); err != nil {
// return err
// }
func NewBatchDeleteWithClient(client s3iface.S3API, options ...func(*BatchDelete)) *BatchDelete {
svc := &BatchDelete{
Client: client,
BatchSize: DefaultBatchSize,
}
for _, opt := range options {
opt(svc)
}
return svc
}
// NewBatchDelete will return a new delete client that can delete a batched amount of
// objects.
//
// Example:
// batcher := s3manager.NewBatchDelete(sess, size)
//
// objects := []BatchDeleteObject{
// {
// Object: &s3.DeleteObjectInput {
// Key: aws.String("key"),
// Bucket: aws.String("bucket"),
// },
// },
// }
//
// if err := batcher.Delete(aws.BackgroundContext(), &s3manager.DeleteObjectsIterator{
// Objects: objects,
// }); err != nil {
// return err
// }
func NewBatchDelete(c client.ConfigProvider, options ...func(*BatchDelete)) *BatchDelete {
client := s3.New(c)
return NewBatchDeleteWithClient(client, options...)
}
// BatchDeleteObject is a wrapper object for calling the batch delete operation.
type BatchDeleteObject struct {
Object *s3.DeleteObjectInput
// After will run after each iteration during the batch process. This function will
// be executed whether or not the request was successful.
After func() error
}
// DeleteObjectsIterator is an interface that uses the scanner pattern to iterate
// through a series of objects to be deleted.
type DeleteObjectsIterator struct {
Objects []BatchDeleteObject
index int
inc bool
}
// Next will increment the default iterator's index and and ensure that there
// is another object to iterator to.
func (iter *DeleteObjectsIterator) Next() bool {
if iter.inc {
iter.index++
} else {
iter.inc = true
}
return iter.index < len(iter.Objects)
}
// Err will return an error. Since this is just used to satisfy the BatchDeleteIterator interface
// this will only return nil.
func (iter *DeleteObjectsIterator) Err() error {
return nil
}
// DeleteObject will return the BatchDeleteObject at the current batched index.
func (iter *DeleteObjectsIterator) DeleteObject() BatchDeleteObject {
object := iter.Objects[iter.index]
return object
}
// Delete will use the iterator to queue up objects that need to be deleted.
// Once the batch size is met, this will call the deleteBatch function.
func (d *BatchDelete) Delete(ctx aws.Context, iter BatchDeleteIterator) error {
var errs []Error
objects := []BatchDeleteObject{}
var input *s3.DeleteObjectsInput
for iter.Next() {
o := iter.DeleteObject()
if input == nil {
input = initDeleteObjectsInput(o.Object)
}
parity := hasParity(input, o)
if parity {
input.Delete.Objects = append(input.Delete.Objects, &s3.ObjectIdentifier{
Key: o.Object.Key,
VersionId: o.Object.VersionId,
})
objects = append(objects, o)
}
if len(input.Delete.Objects) == d.BatchSize || !parity {
if err := deleteBatch(ctx, d, input, objects); err != nil {
errs = append(errs, err...)
}
objects = objects[:0]
input = nil
if !parity {
objects = append(objects, o)
input = initDeleteObjectsInput(o.Object)
input.Delete.Objects = append(input.Delete.Objects, &s3.ObjectIdentifier{
Key: o.Object.Key,
VersionId: o.Object.VersionId,
})
}
}
}
// iter.Next() could return false (above) plus populate iter.Err()
if iter.Err() != nil {
errs = append(errs, newError(iter.Err(), nil, nil))
}
if input != nil && len(input.Delete.Objects) > 0 {
if err := deleteBatch(ctx, d, input, objects); err != nil {
errs = append(errs, err...)
}
}
if len(errs) > 0 {
return NewBatchError("BatchedDeleteIncomplete", "some objects have failed to be deleted.", errs)
}
return nil
}
func initDeleteObjectsInput(o *s3.DeleteObjectInput) *s3.DeleteObjectsInput {
return &s3.DeleteObjectsInput{
Bucket: o.Bucket,
MFA: o.MFA,
RequestPayer: o.RequestPayer,
Delete: &s3.Delete{},
}
}
const (
// ErrDeleteBatchFailCode represents an error code which will be returned
// only when DeleteObjects.Errors has an error that does not contain a code.
ErrDeleteBatchFailCode = "DeleteBatchError"
errDefaultDeleteBatchMessage = "failed to delete"
)
// deleteBatch will delete a batch of items in the objects parameters.
func deleteBatch(ctx aws.Context, d *BatchDelete, input *s3.DeleteObjectsInput, objects []BatchDeleteObject) []Error {
errs := []Error{}
if result, err := d.Client.DeleteObjectsWithContext(ctx, input); err != nil {
for i := 0; i < len(input.Delete.Objects); i++ {
errs = append(errs, newError(err, input.Bucket, input.Delete.Objects[i].Key))
}
} else if len(result.Errors) > 0 {
for i := 0; i < len(result.Errors); i++ {
code := ErrDeleteBatchFailCode
msg := errDefaultDeleteBatchMessage
if result.Errors[i].Message != nil {
msg = *result.Errors[i].Message
}
if result.Errors[i].Code != nil {
code = *result.Errors[i].Code
}
errs = append(errs, newError(awserr.New(code, msg, err), input.Bucket, result.Errors[i].Key))
}
}
for _, object := range objects {
if object.After == nil {
continue
}
if err := object.After(); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
}
return errs
}
func hasParity(o1 *s3.DeleteObjectsInput, o2 BatchDeleteObject) bool {
if o1.Bucket != nil && o2.Object.Bucket != nil {
if *o1.Bucket != *o2.Object.Bucket {
return false
}
} else if o1.Bucket != o2.Object.Bucket {
return false
}
if o1.MFA != nil && o2.Object.MFA != nil {
if *o1.MFA != *o2.Object.MFA {
return false
}
} else if o1.MFA != o2.Object.MFA {
return false
}
if o1.RequestPayer != nil && o2.Object.RequestPayer != nil {
if *o1.RequestPayer != *o2.Object.RequestPayer {
return false
}
} else if o1.RequestPayer != o2.Object.RequestPayer {
return false
}
return true
}
// BatchDownloadIterator is an interface that uses the scanner pattern to iterate
// through a series of objects to be downloaded.
type BatchDownloadIterator interface {
Next() bool
Err() error
DownloadObject() BatchDownloadObject
}
// BatchDownloadObject contains all necessary information to run a batch operation once.
type BatchDownloadObject struct {
Object *s3.GetObjectInput
Writer io.WriterAt
// After will run after each iteration during the batch process. This function will
// be executed whether or not the request was successful.
After func() error
}
// DownloadObjectsIterator implements the BatchDownloadIterator interface and allows for batched
// download of objects.
type DownloadObjectsIterator struct {
Objects []BatchDownloadObject
index int
inc bool
}
// Next will increment the default iterator's index and and ensure that there
// is another object to iterator to.
func (batcher *DownloadObjectsIterator) Next() bool {
if batcher.inc {
batcher.index++
} else {
batcher.inc = true
}
return batcher.index < len(batcher.Objects)
}
// DownloadObject will return the BatchDownloadObject at the current batched index.
func (batcher *DownloadObjectsIterator) DownloadObject() BatchDownloadObject {
object := batcher.Objects[batcher.index]
return object
}
// Err will return an error. Since this is just used to satisfy the BatchDeleteIterator interface
// this will only return nil.
func (batcher *DownloadObjectsIterator) Err() error {
return nil
}
// BatchUploadIterator is an interface that uses the scanner pattern to
// iterate through what needs to be uploaded.
type BatchUploadIterator interface {
Next() bool
Err() error
UploadObject() BatchUploadObject
}
// UploadObjectsIterator implements the BatchUploadIterator interface and allows for batched
// upload of objects.
type UploadObjectsIterator struct {
Objects []BatchUploadObject
index int
inc bool
}
// Next will increment the default iterator's index and and ensure that there
// is another object to iterator to.
func (batcher *UploadObjectsIterator) Next() bool {
if batcher.inc {
batcher.index++
} else {
batcher.inc = true
}
return batcher.index < len(batcher.Objects)
}
// Err will return an error. Since this is just used to satisfy the BatchUploadIterator interface
// this will only return nil.
func (batcher *UploadObjectsIterator) Err() error {
return nil
}
// UploadObject will return the BatchUploadObject at the current batched index.
func (batcher *UploadObjectsIterator) UploadObject() BatchUploadObject {
object := batcher.Objects[batcher.index]
return object
}
// BatchUploadObject contains all necessary information to run a batch operation once.
type BatchUploadObject struct {
Object *UploadInput
// After will run after each iteration during the batch process. This function will
// be executed whether or not the request was successful.
After func() error
}

View File

@@ -0,0 +1,116 @@
// +build go1.7
package s3manager
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
)
// #1790 bug
func TestBatchDeleteContext(t *testing.T) {
cases := []struct {
objects []BatchDeleteObject
size int
expected int
ctx aws.Context
closeAt int
errCheck func(error) (string, bool)
}{
{
[]BatchDeleteObject{
{
Object: &s3.DeleteObjectInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("2"),
Bucket: aws.String("bucket2"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("4"),
Bucket: aws.String("bucket4"),
},
},
},
1,
0,
aws.BackgroundContext(),
0,
func(err error) (string, bool) {
batchErr, ok := err.(*BatchError)
if !ok {
return "not BatchError type", false
}
errs := batchErr.Errors
if len(errs) != 4 {
return fmt.Sprintf("expected 1, but received %d", len(errs)), false
}
for _, tempErr := range errs {
aerr, ok := tempErr.OrigErr.(awserr.Error)
if !ok {
return "not awserr.Error type", false
}
if code := aerr.Code(); code != request.CanceledErrorCode {
return fmt.Sprintf("expected %q, but received %q", request.CanceledErrorCode, code), false
}
}
return "", true
},
},
}
count := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
count++
}))
svc := &mockS3Client{S3: buildS3SvcClient(server.URL)}
for i, c := range cases {
ctx, done := context.WithCancel(c.ctx)
defer done()
if i == c.closeAt {
done()
}
batcher := BatchDelete{
Client: svc,
BatchSize: c.size,
}
err := batcher.Delete(ctx, &DeleteObjectsIterator{Objects: c.objects})
if msg, ok := c.errCheck(err); !ok {
t.Error(msg)
}
if count != c.expected {
t.Errorf("Case %d: expected %d, but received %d", i, c.expected, count)
}
count = 0
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
package s3manager
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// GetBucketRegion will attempt to get the region for a bucket using the
// regionHint to determine which AWS partition to perform the query on.
//
// The request will not be signed, and will not use your AWS credentials.
//
// A "NotFound" error code will be returned if the bucket does not exist in the
// AWS partition the regionHint belongs to. If the regionHint parameter is an
// empty string GetBucketRegion will fallback to the ConfigProvider's region
// config. If the regionHint is empty, and the ConfigProvider does not have a
// region value, an error will be returned..
//
// For example to get the region of a bucket which exists in "eu-central-1"
// you could provide a region hint of "us-west-2".
//
// sess := session.Must(session.NewSession())
//
// bucket := "my-bucket"
// region, err := s3manager.GetBucketRegion(ctx, sess, bucket, "us-west-2")
// if err != nil {
// if aerr, ok := err.(awserr.Error); ok && aerr.Code() == "NotFound" {
// fmt.Fprintf(os.Stderr, "unable to find bucket %s's region not found\n", bucket)
// }
// return err
// }
// fmt.Printf("Bucket %s is in %s region\n", bucket, region)
//
func GetBucketRegion(ctx aws.Context, c client.ConfigProvider, bucket, regionHint string, opts ...request.Option) (string, error) {
var cfg aws.Config
if len(regionHint) != 0 {
cfg.Region = aws.String(regionHint)
}
svc := s3.New(c, &cfg)
return GetBucketRegionWithClient(ctx, svc, bucket, opts...)
}
const bucketRegionHeader = "X-Amz-Bucket-Region"
// GetBucketRegionWithClient is the same as GetBucketRegion with the exception
// that it takes a S3 service client instead of a Session. The regionHint is
// derived from the region the S3 service client was created in.
//
// See GetBucketRegion for more information.
func GetBucketRegionWithClient(ctx aws.Context, svc s3iface.S3API, bucket string, opts ...request.Option) (string, error) {
req, _ := svc.HeadBucketRequest(&s3.HeadBucketInput{
Bucket: aws.String(bucket),
})
req.Config.S3ForcePathStyle = aws.Bool(true)
req.Config.Credentials = credentials.AnonymousCredentials
req.SetContext(ctx)
// Disable HTTP redirects to prevent an invalid 301 from eating the response
// because Go's HTTP client will fail, and drop the response if an 301 is
// received without a location header. S3 will return a 301 without the
// location header for HeadObject API calls.
req.DisableFollowRedirects = true
var bucketRegion string
req.Handlers.Send.PushBack(func(r *request.Request) {
bucketRegion = r.HTTPResponse.Header.Get(bucketRegionHeader)
if len(bucketRegion) == 0 {
return
}
r.HTTPResponse.StatusCode = 200
r.HTTPResponse.Status = "OK"
r.Error = nil
})
req.ApplyOptions(opts...)
if err := req.Send(); err != nil {
return "", err
}
bucketRegion = s3.NormalizeBucketLocation(bucketRegion)
return bucketRegion, nil
}

View File

@@ -0,0 +1,96 @@
package s3manager
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func testSetupGetBucketRegionServer(region string, statusCode int, incHeader bool) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if incHeader {
w.Header().Set(bucketRegionHeader, region)
}
w.WriteHeader(statusCode)
}))
}
var testGetBucketRegionCases = []struct {
RespRegion string
StatusCode int
HintRegion string
ExpectReqRegion string
}{
{"bucket-region", 301, "hint-region", ""},
{"bucket-region", 403, "hint-region", ""},
{"bucket-region", 200, "hint-region", ""},
{"bucket-region", 200, "", "default-region"},
}
func TestGetBucketRegion_Exists(t *testing.T) {
for i, c := range testGetBucketRegionCases {
server := testSetupGetBucketRegionServer(c.RespRegion, c.StatusCode, true)
sess := unit.Session.Copy()
sess.Config.Region = aws.String("default-region")
sess.Config.Endpoint = aws.String(server.URL)
sess.Config.DisableSSL = aws.Bool(true)
ctx := aws.BackgroundContext()
region, err := GetBucketRegion(ctx, sess, "bucket", c.HintRegion)
if err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.RespRegion, region; e != a {
t.Errorf("%d, expect %q region, got %q", i, e, a)
}
}
}
func TestGetBucketRegion_NotExists(t *testing.T) {
server := testSetupGetBucketRegionServer("ignore-region", 404, false)
sess := unit.Session.Copy()
sess.Config.Endpoint = aws.String(server.URL)
sess.Config.DisableSSL = aws.Bool(true)
ctx := aws.BackgroundContext()
region, err := GetBucketRegion(ctx, sess, "bucket", "hint-region")
if err == nil {
t.Fatalf("expect error, but did not get one")
}
aerr := err.(awserr.Error)
if e, a := "NotFound", aerr.Code(); e != a {
t.Errorf("expect %s error code, got %s", e, a)
}
if len(region) != 0 {
t.Errorf("expect region not to be set, got %q", region)
}
}
func TestGetBucketRegionWithClient(t *testing.T) {
for i, c := range testGetBucketRegionCases {
server := testSetupGetBucketRegionServer(c.RespRegion, c.StatusCode, true)
svc := s3.New(unit.Session, &aws.Config{
Region: aws.String("hint-region"),
Endpoint: aws.String(server.URL),
DisableSSL: aws.Bool(true),
})
ctx := aws.BackgroundContext()
region, err := GetBucketRegionWithClient(ctx, svc, "bucket")
if err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.RespRegion, region; e != a {
t.Errorf("%d, expect %q region, got %q", i, e, a)
}
}
}

View File

@@ -0,0 +1,3 @@
// Package s3manager provides utilities to upload and download objects from
// S3 concurrently. Helpful for when working with large objects.
package s3manager

View File

@@ -0,0 +1,555 @@
package s3manager
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// DefaultDownloadPartSize is the default range of bytes to get at a time when
// using Download().
const DefaultDownloadPartSize = 1024 * 1024 * 5
// DefaultDownloadConcurrency is the default number of goroutines to spin up
// when using Download().
const DefaultDownloadConcurrency = 5
// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Downloader's properties is not safe to be done concurrently.
type Downloader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultDownloadPartSize value will be used.
//
// PartSize is ignored if the Range input parameter is provided.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultDownloadConcurrency value will be used.
//
// Concurrency of 1 will download the parts sequentially.
//
// Concurrency is ignored if the Range input parameter is provided.
Concurrency int
// An S3 client to use when performing downloads.
S3 s3iface.S3API
// List of request options that will be passed down to individual API
// operation requests made by the downloader.
RequestOptions []request.Option
}
// WithDownloaderRequestOptions appends to the Downloader's API request options.
func WithDownloaderRequestOptions(opts ...request.Option) func(*Downloader) {
return func(d *Downloader) {
d.RequestOptions = append(d.RequestOptions, opts...)
}
}
// NewDownloader creates a new Downloader instance to downloads objects from
// S3 in concurrent chunks. Pass in additional functional options to customize
// the downloader behavior. Requires a client.ConfigProvider in order to create
// a S3 service client. The session.Session satisfies the client.ConfigProvider
// interface.
//
// Example:
// // The session the S3 Downloader will use
// sess := session.Must(session.NewSession())
//
// // Create a downloader with the session and default options
// downloader := s3manager.NewDownloader(sess)
//
// // Create a downloader with the session and custom options
// downloader := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
d := &Downloader{
S3: s3.New(c),
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
for _, option := range options {
option(d)
}
return d
}
// NewDownloaderWithClient creates a new Downloader instance to downloads
// objects from S3 in concurrent chunks. Pass in additional functional
// options to customize the downloader behavior. Requires a S3 service client
// to make S3 API calls.
//
// Example:
// // The session the S3 Downloader will use
// sess := session.Must(session.NewSession())
//
// // The S3 client the S3 Downloader will use
// s3Svc := s3.new(sess)
//
// // Create a downloader with the s3 client and default options
// downloader := s3manager.NewDownloaderWithClient(s3Svc)
//
// // Create a downloader with the s3 client and custom options
// downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Downloader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
d := &Downloader{
S3: svc,
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
for _, option := range options {
option(d)
}
return d
}
type maxRetrier interface {
MaxRetries() int
}
// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is called from.
// Modifying the options will not impact the original Downloader instance.
//
// It is safe to call this method concurrently across goroutines.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
//
// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
// download the parts from S3 sequentially.
//
// If the GetObjectInput's Range value is provided that will cause the downloader
// to perform a single GetObjectInput request for that object's range. This will
// caused the part size, and concurrency configurations to be ignored.
func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
return d.DownloadWithContext(aws.BackgroundContext(), w, input, options...)
}
// DownloadWithContext downloads an object in S3 and writes the payload into w
// using concurrent GET requests.
//
// DownloadWithContext is the same as Download with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlining, timeouts, etc. The
// DownloadWithContext may create sub-contexts for individual underlying
// requests.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is
// called from. Modifying the options will not impact the original Downloader
// instance. Use the WithDownloaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this downloader.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
//
// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
// download the parts from S3 sequentially.
//
// It is safe to call this method concurrently across goroutines.
//
// If the GetObjectInput's Range value is provided that will cause the downloader
// to perform a single GetObjectInput request for that object's range. This will
// caused the part size, and concurrency configurations to be ignored.
func (d Downloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
impl := downloader{w: w, in: input, cfg: d, ctx: ctx}
for _, option := range options {
option(&impl.cfg)
}
impl.cfg.RequestOptions = append(impl.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager"))
if s, ok := d.S3.(maxRetrier); ok {
impl.partBodyMaxRetries = s.MaxRetries()
}
impl.totalBytes = -1
if impl.cfg.Concurrency == 0 {
impl.cfg.Concurrency = DefaultDownloadConcurrency
}
if impl.cfg.PartSize == 0 {
impl.cfg.PartSize = DefaultDownloadPartSize
}
return impl.download()
}
// DownloadWithIterator will download a batched amount of objects in S3 and writes them
// to the io.WriterAt specificed in the iterator.
//
// Example:
// svc := s3manager.NewDownloader(session)
//
// fooFile, err := os.Open("/tmp/foo.file")
// if err != nil {
// return err
// }
//
// barFile, err := os.Open("/tmp/bar.file")
// if err != nil {
// return err
// }
//
// objects := []s3manager.BatchDownloadObject {
// {
// Input: &s3.GetObjectInput {
// Bucket: aws.String("bucket"),
// Key: aws.String("foo"),
// },
// Writer: fooFile,
// },
// {
// Input: &s3.GetObjectInput {
// Bucket: aws.String("bucket"),
// Key: aws.String("bar"),
// },
// Writer: barFile,
// },
// }
//
// iter := &s3manager.DownloadObjectsIterator{Objects: objects}
// if err := svc.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil {
// return err
// }
func (d Downloader) DownloadWithIterator(ctx aws.Context, iter BatchDownloadIterator, opts ...func(*Downloader)) error {
var errs []Error
for iter.Next() {
object := iter.DownloadObject()
if _, err := d.DownloadWithContext(ctx, object.Writer, object.Object, opts...); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
if object.After == nil {
continue
}
if err := object.After(); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
}
if len(errs) > 0 {
return NewBatchError("BatchedDownloadIncomplete", "some objects have failed to download.", errs)
}
return nil
}
// downloader is the implementation structure used internally by Downloader.
type downloader struct {
ctx aws.Context
cfg Downloader
in *s3.GetObjectInput
w io.WriterAt
wg sync.WaitGroup
m sync.Mutex
pos int64
totalBytes int64
written int64
err error
partBodyMaxRetries int
}
// download performs the implementation of the object download across ranged
// GETs.
func (d *downloader) download() (n int64, err error) {
// If range is specified fall back to single download of that range
// this enables the functionality of ranged gets with the downloader but
// at the cost of no multipart downloads.
if rng := aws.StringValue(d.in.Range); len(rng) > 0 {
d.downloadRange(rng)
return d.written, d.err
}
// Spin off first worker to check additional header information
d.getChunk()
if total := d.getTotalBytes(); total >= 0 {
// Spin up workers
ch := make(chan dlchunk, d.cfg.Concurrency)
for i := 0; i < d.cfg.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ch)
}
// Assign work
for d.getErr() == nil {
if d.pos >= total {
break // We're finished queuing chunks
}
// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
d.pos += d.cfg.PartSize
}
// Wait for completion
close(ch)
d.wg.Wait()
} else {
// Checking if we read anything new
for d.err == nil {
d.getChunk()
}
// We expect a 416 error letting us know we are done downloading the
// total bytes. Since we do not know the content's length, this will
// keep grabbing chunks of data until the range of bytes specified in
// the request is out of range of the content. Once, this happens, a
// 416 should occur.
e, ok := d.err.(awserr.RequestFailure)
if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
d.err = nil
}
}
// Return error
return d.written, d.err
}
// downloadPart is an individual goroutine worker reading from the ch channel
// and performing a GetObject request on the data with a given byte range.
//
// If this is the first worker, this operation also resolves the total number
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
defer d.wg.Done()
for {
chunk, ok := <-ch
if !ok {
break
}
if d.getErr() != nil {
// Drain the channel if there is an error, to prevent deadlocking
// of download producer.
continue
}
if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
}
}
}
// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only used when grabbing data on a single thread.
func (d *downloader) getChunk() {
if d.getErr() != nil {
return
}
chunk := dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
d.pos += d.cfg.PartSize
if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
}
}
// downloadRange downloads an Object given the passed in Byte-Range value.
// The chunk used down download the range will be configured for that range.
func (d *downloader) downloadRange(rng string) {
if d.getErr() != nil {
return
}
chunk := dlchunk{w: d.w, start: d.pos}
// Ranges specified will short circuit the multipart download
chunk.withRange = rng
if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
}
// Update the position based on the amount of data received.
d.pos = d.written
}
// downloadChunk downloads the chunk from s3
func (d *downloader) downloadChunk(chunk dlchunk) error {
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
// Get the next byte range of data
in.Range = aws.String(chunk.ByteRange())
var n int64
var err error
for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
var resp *s3.GetObjectOutput
resp, err = d.cfg.S3.GetObjectWithContext(d.ctx, in, d.cfg.RequestOptions...)
if err != nil {
return err
}
d.setTotalBytes(resp) // Set total if not yet set.
n, err = io.Copy(&chunk, resp.Body)
resp.Body.Close()
if err == nil {
break
}
chunk.cur = 0
logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries,
fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
aws.StringValue(in.Key), err, retry))
}
d.incrWritten(n)
return err
}
func logMessage(svc s3iface.S3API, level aws.LogLevelType, msg string) {
s, ok := svc.(*s3.S3)
if !ok {
return
}
if s.Config.Logger == nil {
return
}
if s.Config.LogLevel.Matches(level) {
s.Config.Logger.Log(msg)
}
}
// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
d.m.Lock()
defer d.m.Unlock()
return d.totalBytes
}
// setTotalBytes is a thread-safe setter for setting the total byte status.
// Will extract the object's total bytes from the Content-Range if the file
// will be chunked, or Content-Length. Content-Length is used when the response
// does not include a Content-Range. Meaning the object was not chunked. This
// occurs when the full file fits within the PartSize directive.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
d.m.Lock()
defer d.m.Unlock()
if d.totalBytes >= 0 {
return
}
if resp.ContentRange == nil {
// ContentRange is nil when the full file contents is provided, and
// is not chunked. Use ContentLength instead.
if resp.ContentLength != nil {
d.totalBytes = *resp.ContentLength
return
}
} else {
parts := strings.Split(*resp.ContentRange, "/")
total := int64(-1)
var err error
// Checking for whether or not a numbered total exists
// If one does not exist, we will assume the total to be -1, undefined,
// and sequentially download each chunk until hitting a 416 error
totalStr := parts[len(parts)-1]
if totalStr != "*" {
total, err = strconv.ParseInt(totalStr, 10, 64)
if err != nil {
d.err = err
return
}
}
d.totalBytes = total
}
}
func (d *downloader) incrWritten(n int64) {
d.m.Lock()
defer d.m.Unlock()
d.written += n
}
// getErr is a thread-safe getter for the error object
func (d *downloader) getErr() error {
d.m.Lock()
defer d.m.Unlock()
return d.err
}
// setErr is a thread-safe setter for the error object
func (d *downloader) setErr(e error) {
d.m.Lock()
defer d.m.Unlock()
d.err = e
}
// dlchunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
w io.WriterAt
start int64
size int64
cur int64
// specifies the byte range the chunk should be downloaded with.
withRange string
}
// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
//
// If a range is specified on the dlchunk the size will be ignored when writing.
// as the total size may not of be known ahead of time.
func (c *dlchunk) Write(p []byte) (n int, err error) {
if c.cur >= c.size && len(c.withRange) == 0 {
return 0, io.EOF
}
n, err = c.w.WriteAt(p, c.start+c.cur)
c.cur += int64(n)
return
}
// ByteRange returns a HTTP Byte-Range header value that should be used by the
// client to request the chunk's range.
func (c *dlchunk) ByteRange() string {
if len(c.withRange) != 0 {
return c.withRange
}
return fmt.Sprintf("bytes=%d-%d", c.start, c.start+c.size-1)
}

View File

@@ -0,0 +1,658 @@
package s3manager_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
func dlLoggingSvc(data []byte) (*s3.S3, *[]string, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++
if fin > int64(len(data)) {
fin = int64(len(data))
}
bodyBytes := data[start:fin]
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d",
start, fin-1, len(data)))
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
})
return svc, &names, &ranges
}
func dlLoggingSvcNoChunk(data []byte) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
r.HTTPResponse = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(data[:])),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", len(data)))
})
return svc, &names
}
func dlLoggingSvcNoContentRangeLength(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
var index int = 0
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(data[:])),
Header: http.Header{},
}
index++
})
return svc, &names
}
func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
ranges := []string{}
var index int = 0
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
ranges = append(ranges, *r.Params.(*s3.GetObjectInput).Range)
rerng := regexp.MustCompile(`bytes=(\d+)-(\d+)`)
rng := rerng.FindStringSubmatch(r.HTTPRequest.Header.Get("Range"))
start, _ := strconv.ParseInt(rng[1], 10, 64)
fin, _ := strconv.ParseInt(rng[2], 10, 64)
fin++
if fin >= int64(len(data)) {
fin = int64(len(data))
}
// Setting start and finish to 0 because this state of 1 is suppose to
// be an error state of 416
if index == len(states)-1 {
start = 0
fin = 0
}
bodyBytes := data[start:fin]
r.HTTPResponse = &http.Response{
StatusCode: states[index],
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/*",
start, fin-1))
index++
})
return svc, &names
}
func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.S3, *[]string) {
var m sync.Mutex
names := []string{}
var index int = 0
svc := s3.New(unit.Session, &aws.Config{
MaxRetries: aws.Int(len(cases) - 1),
})
svc.Handlers.Send.Clear()
svc.Handlers.Send.PushBack(func(r *request.Request) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
c := cases[index]
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(&c),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Range",
fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len))
r.HTTPResponse.Header.Set("Content-Length", fmt.Sprintf("%d", c.Len))
index++
})
return svc, &names
}
func TestDownloadOrder(t *testing.T) {
s, names, ranges := dlLoggingSvc(buf12MB)
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf12MB)), n; e != a {
t.Errorf("expect %d buffer length, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadZero(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{})
d := s3manager.NewDownloaderWithClient(s)
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if n != 0 {
t.Errorf("expect 0 bytes read, got %d", n)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=0-5242879"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
}
func TestDownloadSetPartSize(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{1, 2, 3})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
d.PartSize = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(3), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
expectBytes := []byte{1, 2, 3}
if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
t.Errorf("expect %v bytes, got %v", e, a)
}
}
func TestDownloadError(t *testing.T) {
s, names, _ := dlLoggingSvc([]byte{1, 2, 3})
num := 0
s.Handlers.Send.PushBack(func(r *request.Request) {
num++
if num > 1 {
r.HTTPResponse.StatusCode = 400
r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
}
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
d.PartSize = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err == nil {
t.Fatalf("expect error, got none")
}
aerr := err.(awserr.Error)
if e, a := "BadRequest", aerr.Code(); e != a {
t.Errorf("expect %s error code, got %s", e, a)
}
if e, a := int64(1), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectBytes := []byte{1}
if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
t.Errorf("expect %v bytes, got %v", e, a)
}
}
func TestDownloadNonChunk(t *testing.T) {
s, names := dlLoggingSvcNoChunk(buf2MB)
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf2MB)), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadNoContentRangeLength(t *testing.T) {
s, names := dlLoggingSvcNoContentRangeLength(buf2MB, []int{200, 416})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf2MB)), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadContentRangeTotalAny(t *testing.T) {
s, names := dlLoggingSvcContentRangeTotalAny(buf2MB, []int{200, 416})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf2MB)), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) {
s, names := dlLoggingSvcWithErrReader([]testErrReader{
{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
{Buf: []byte("123"), Len: 3, Err: io.EOF},
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(3), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
if e, a := "123", string(w.Bytes()); e != a {
t.Errorf("expect %q response, got %q", e, a)
}
}
func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) {
s, names := dlLoggingSvcWithErrReader([]testErrReader{
{Buf: []byte("abc"), Len: 3, Err: io.EOF},
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(3), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
if e, a := "abc", string(w.Bytes()); e != a {
t.Errorf("expect %q response, got %q", e, a)
}
}
func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
s, names := dlLoggingSvcWithErrReader([]testErrReader{
{Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF},
})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 1
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := "unexpected EOF", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q error message to be in %q", e, a)
}
if e, a := int64(2), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
if e, a := "ab", string(w.Bytes()); e != a {
t.Errorf("expect %q response, got %q", e, a)
}
}
func TestDownloadWithContextCanceled(t *testing.T) {
d := s3manager.NewDownloader(unit.Session)
params := s3.GetObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
}
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
w := &aws.WriteAtBuffer{}
_, err := d.DownloadWithContext(ctx, w, &params)
if err == nil {
t.Fatalf("expected error, did not get one")
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expected error code %q, got %q", e, a)
}
if e, a := "canceled", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}
func TestDownload_WithRange(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 10 // should be ignored
d.PartSize = 1 // should be ignored
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
Range: aws.String("bytes=2-6"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(5), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=2-6"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
expectBytes := []byte{2, 3, 4, 5, 6}
if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
t.Errorf("expect %v bytes, got %v", e, a)
}
}
func TestDownload_WithFailure(t *testing.T) {
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()
first := true
svc.Handlers.Send.PushBack(func(r *request.Request) {
if first {
first = false
body := bytes.NewReader(make([]byte, s3manager.DefaultDownloadPartSize))
r.HTTPResponse = &http.Response{
StatusCode: http.StatusOK,
Status: http.StatusText(http.StatusOK),
ContentLength: int64(body.Len()),
Body: ioutil.NopCloser(body),
Header: http.Header{},
}
r.HTTPResponse.Header.Set("Content-Length", strconv.Itoa(body.Len()))
r.HTTPResponse.Header.Set("Content-Range",
fmt.Sprintf("bytes 0-%d/%d", body.Len()-1, body.Len()*10))
return
}
// Give a chance for the multipart chunks to be queued up
time.Sleep(1 * time.Second)
r.HTTPResponse = &http.Response{
Header: http.Header{},
Body: ioutil.NopCloser(&bytes.Buffer{}),
}
r.Error = awserr.New("ConnectionError", "some connection error", nil)
r.Retryable = aws.Bool(false)
})
start := time.Now()
d := s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
d.Concurrency = 2
})
w := &aws.WriteAtBuffer{}
params := s3.GetObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
}
// Expect this request to exit quickly after failure
_, err := d.Download(w, &params)
if err == nil {
t.Fatalf("expect error, got none")
}
limit := start.Add(5 * time.Second)
dur := time.Now().Sub(start)
if time.Now().After(limit) {
t.Errorf("expect time to be less than %v, took %v", limit, dur)
}
}
type testErrReader struct {
Buf []byte
Err error
Len int64
off int
}
func (r *testErrReader) Read(p []byte) (int, error) {
to := len(r.Buf) - r.off
n := copy(p, r.Buf[r.off:to])
r.off += n
if n < len(p) {
return n, r.Err
}
return n, nil
}

View File

@@ -0,0 +1,26 @@
// Package s3manageriface provides an interface for the s3manager package
package s3manageriface
import (
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
// DownloaderAPI is the interface type for s3manager.Downloader.
type DownloaderAPI interface {
Download(io.WriterAt, *s3.GetObjectInput, ...func(*s3manager.Downloader)) (int64, error)
DownloadWithContext(aws.Context, io.WriterAt, *s3.GetObjectInput, ...func(*s3manager.Downloader)) (int64, error)
}
var _ DownloaderAPI = (*s3manager.Downloader)(nil)
// UploaderAPI is the interface type for s3manager.Uploader.
type UploaderAPI interface {
Upload(*s3manager.UploadInput, ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
UploadWithContext(aws.Context, *s3manager.UploadInput, ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
var _ UploaderAPI = (*s3manager.Uploader)(nil)

View File

@@ -0,0 +1,4 @@
package s3manager_test
var buf12MB = make([]byte, 1024*1024*12)
var buf2MB = make([]byte, 1024*1024*2)

View File

@@ -0,0 +1,802 @@
package s3manager
import (
"bytes"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// MaxUploadParts is the maximum allowed number of parts in a multi-part upload
// on Amazon S3.
const MaxUploadParts = 10000
// MinUploadPartSize is the minimum allowed part size when uploading a part to
// Amazon S3.
const MinUploadPartSize int64 = 1024 * 1024 * 5
// DefaultUploadPartSize is the default part size to buffer chunks of a
// payload into.
const DefaultUploadPartSize = MinUploadPartSize
// DefaultUploadConcurrency is the default number of goroutines to spin up when
// using Upload().
const DefaultUploadConcurrency = 5
// A MultiUploadFailure wraps a failed S3 multipart upload. An error returned
// will satisfy this interface when a multi part upload failed to upload all
// chucks to S3. In the case of a failure the UploadID is needed to operate on
// the chunks, if any, which were uploaded.
//
// Example:
//
// u := s3manager.NewUploader(opts)
// output, err := u.upload(input)
// if err != nil {
// if multierr, ok := err.(s3manager.MultiUploadFailure); ok {
// // Process error and its associated uploadID
// fmt.Println("Error:", multierr.Code(), multierr.Message(), multierr.UploadID())
// } else {
// // Process error generically
// fmt.Println("Error:", err.Error())
// }
// }
//
type MultiUploadFailure interface {
awserr.Error
// Returns the upload id for the S3 multipart upload that failed.
UploadID() string
}
// So that the Error interface type can be included as an anonymous field
// in the multiUploadError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// A multiUploadError wraps the upload ID of a failed s3 multipart upload.
// Composed of BaseError for code, message, and original error
//
// Should be used for an error that occurred failing a S3 multipart upload,
// and a upload ID is available. If an uploadID is not available a more relevant
type multiUploadError struct {
awsError
// ID for multipart upload which failed.
uploadID string
}
// Error returns the string representation of the error.
//
// See apierr.BaseError ErrorWithExtra for output format
//
// Satisfies the error interface.
func (m multiUploadError) Error() string {
extra := fmt.Sprintf("upload id: %s", m.uploadID)
return awserr.SprintError(m.Code(), m.Message(), extra, m.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (m multiUploadError) String() string {
return m.Error()
}
// UploadID returns the id of the S3 upload which failed.
func (m multiUploadError) UploadID() string {
return m.uploadID
}
// UploadInput contains all input for upload requests to Amazon S3.
type UploadInput struct {
// The canned ACL to apply to the object.
ACL *string `location:"header" locationName:"x-amz-acl" type:"string"`
Bucket *string `location:"uri" locationName:"Bucket" type:"string" required:"true"`
// Specifies caching behavior along the request/reply chain.
CacheControl *string `location:"header" locationName:"Cache-Control" type:"string"`
// Specifies presentational information for the object.
ContentDisposition *string `location:"header" locationName:"Content-Disposition" type:"string"`
// Specifies what content encodings have been applied to the object and thus
// what decoding mechanisms must be applied to obtain the media-type referenced
// by the Content-Type header field.
ContentEncoding *string `location:"header" locationName:"Content-Encoding" type:"string"`
// The language the content is in.
ContentLanguage *string `location:"header" locationName:"Content-Language" type:"string"`
// The base64-encoded 128-bit MD5 digest of the part data.
ContentMD5 *string `location:"header" locationName:"Content-MD5" type:"string"`
// A standard MIME type describing the format of the object data.
ContentType *string `location:"header" locationName:"Content-Type" type:"string"`
// The date and time at which the object is no longer cacheable.
Expires *time.Time `location:"header" locationName:"Expires" type:"timestamp" timestampFormat:"rfc822"`
// Gives the grantee READ, READ_ACP, and WRITE_ACP permissions on the object.
GrantFullControl *string `location:"header" locationName:"x-amz-grant-full-control" type:"string"`
// Allows grantee to read the object data and its metadata.
GrantRead *string `location:"header" locationName:"x-amz-grant-read" type:"string"`
// Allows grantee to read the object ACL.
GrantReadACP *string `location:"header" locationName:"x-amz-grant-read-acp" type:"string"`
// Allows grantee to write the ACL for the applicable object.
GrantWriteACP *string `location:"header" locationName:"x-amz-grant-write-acp" type:"string"`
Key *string `location:"uri" locationName:"Key" type:"string" required:"true"`
// A map of metadata to store with the object in S3.
Metadata map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
// Confirms that the requester knows that she or he will be charged for the
// request. Bucket owners need not specify this parameter in their requests.
// Documentation on downloading objects from requester pays buckets can be found
// at http://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html
RequestPayer *string `location:"header" locationName:"x-amz-request-payer" type:"string"`
// Specifies the algorithm to use to when encrypting the object (e.g., AES256,
// aws:kms).
SSECustomerAlgorithm *string `location:"header" locationName:"x-amz-server-side-encryption-customer-algorithm" type:"string"`
// Specifies the customer-provided encryption key for Amazon S3 to use in encrypting
// data. This value is used to store the object and then it is discarded; Amazon
// does not store the encryption key. The key must be appropriate for use with
// the algorithm specified in the x-amz-server-side-encryption-customer-algorithm
// header.
SSECustomerKey *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key" type:"string"`
// Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321.
// Amazon S3 uses this header for a message integrity check to ensure the encryption
// key was transmitted without error.
SSECustomerKeyMD5 *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key-MD5" type:"string"`
// Specifies the AWS KMS key ID to use for object encryption. All GET and PUT
// requests for an object protected by AWS KMS will fail if not made via SSL
// or using SigV4. Documentation on configuring any of the officially supported
// AWS SDKs and CLI can be found at http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingAWSSDK.html#specify-signature-version
SSEKMSKeyId *string `location:"header" locationName:"x-amz-server-side-encryption-aws-kms-key-id" type:"string"`
// The Server-side encryption algorithm used when storing this object in S3
// (e.g., AES256, aws:kms).
ServerSideEncryption *string `location:"header" locationName:"x-amz-server-side-encryption" type:"string"`
// The type of storage to use for the object. Defaults to 'STANDARD'.
StorageClass *string `location:"header" locationName:"x-amz-storage-class" type:"string"`
// The tag-set for the object. The tag-set must be encoded as URL Query parameters
Tagging *string `location:"header" locationName:"x-amz-tagging" type:"string"`
// If the bucket is configured as a website, redirects requests for this object
// to another object in the same bucket or to an external URL. Amazon S3 stores
// the value of this header in the object metadata.
WebsiteRedirectLocation *string `location:"header" locationName:"x-amz-website-redirect-location" type:"string"`
// The readable body payload to send to S3.
Body io.Reader
}
// UploadOutput represents a response from the Upload() call.
type UploadOutput struct {
// The URL where the object was uploaded to.
Location string
// The version of the object that was uploaded. Will only be populated if
// the S3 Bucket is versioned. If the bucket is not versioned this field
// will not be set.
VersionID *string
// The ID for a multipart upload to S3. In the case of an error the error
// can be cast to the MultiUploadFailure interface to extract the upload ID.
UploadID string
}
// WithUploaderRequestOptions appends to the Uploader's API request options.
func WithUploaderRequestOptions(opts ...request.Option) func(*Uploader) {
return func(u *Uploader) {
u.RequestOptions = append(u.RequestOptions, opts...)
}
}
// The Uploader structure that calls Upload(). It is safe to call Upload()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Uploader's properties is not safe to be done concurrently.
type Uploader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultUploadPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel per call to Upload when
// sending parts. If this is set to zero, the DefaultUploadConcurrency value
// will be used.
//
// The concurrency pool is not shared between calls to Upload.
Concurrency int
// Setting this value to true will cause the SDK to avoid calling
// AbortMultipartUpload on a failure, leaving all successfully uploaded
// parts on S3 for manual recovery.
//
// Note that storing parts of an incomplete multipart upload counts towards
// space usage on S3 and will add additional costs if not cleaned up.
LeavePartsOnError bool
// MaxUploadParts is the max number of parts which will be uploaded to S3.
// Will be used to calculate the partsize of the object to be uploaded.
// E.g: 5GB file, with MaxUploadParts set to 100, will upload the file
// as 100, 50MB parts.
// With a limited of s3.MaxUploadParts (10,000 parts).
//
// Defaults to package const's MaxUploadParts value.
MaxUploadParts int
// The client to use when uploading to S3.
S3 s3iface.S3API
// List of request options that will be passed down to individual API
// operation requests made by the uploader.
RequestOptions []request.Option
}
// NewUploader creates a new Uploader instance to upload objects to S3. Pass In
// additional functional options to customize the uploader's behavior. Requires a
// client.ConfigProvider in order to create a S3 service client. The session.Session
// satisfies the client.ConfigProvider interface.
//
// Example:
// // The session the S3 Uploader will use
// sess := session.Must(session.NewSession())
//
// // Create an uploader with the session and default options
// uploader := s3manager.NewUploader(sess)
//
// // Create an uploader with the session and custom options
// uploader := s3manager.NewUploader(session, func(u *s3manager.Uploader) {
// u.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewUploader(c client.ConfigProvider, options ...func(*Uploader)) *Uploader {
u := &Uploader{
S3: s3.New(c),
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
MaxUploadParts: MaxUploadParts,
}
for _, option := range options {
option(u)
}
return u
}
// NewUploaderWithClient creates a new Uploader instance to upload objects to S3. Pass in
// additional functional options to customize the uploader's behavior. Requires
// a S3 service client to make S3 API calls.
//
// Example:
// // The session the S3 Uploader will use
// sess := session.Must(session.NewSession())
//
// // S3 service client the Upload manager will use.
// s3Svc := s3.New(sess)
//
// // Create an uploader with S3 client and default options
// uploader := s3manager.NewUploaderWithClient(s3Svc)
//
// // Create an uploader with S3 client and custom options
// uploader := s3manager.NewUploaderWithClient(s3Svc, func(u *s3manager.Uploader) {
// u.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewUploaderWithClient(svc s3iface.S3API, options ...func(*Uploader)) *Uploader {
u := &Uploader{
S3: svc,
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
MaxUploadParts: MaxUploadParts,
}
for _, option := range options {
option(u)
}
return u
}
// Upload uploads an object to S3, intelligently buffering large files into
// smaller chunks and sending them in parallel across multiple goroutines. You
// can configure the buffer size and concurrency through the Uploader's parameters.
//
// Additional functional options can be provided to configure the individual
// upload. These options are copies of the Uploader instance Upload is called from.
// Modifying the options will not impact the original Uploader instance.
//
// Use the WithUploaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this uploader.
//
// It is safe to call this method concurrently across goroutines.
//
// Example:
// // Upload input parameters
// upParams := &s3manager.UploadInput{
// Bucket: &bucketName,
// Key: &keyName,
// Body: file,
// }
//
// // Perform an upload.
// result, err := uploader.Upload(upParams)
//
// // Perform upload with options different than the those in the Uploader.
// result, err := uploader.Upload(upParams, func(u *s3manager.Uploader) {
// u.PartSize = 10 * 1024 * 1024 // 10MB part size
// u.LeavePartsOnError = true // Don't delete the parts if the upload fails.
// })
func (u Uploader) Upload(input *UploadInput, options ...func(*Uploader)) (*UploadOutput, error) {
return u.UploadWithContext(aws.BackgroundContext(), input, options...)
}
// UploadWithContext uploads an object to S3, intelligently buffering large
// files into smaller chunks and sending them in parallel across multiple
// goroutines. You can configure the buffer size and concurrency through the
// Uploader's parameters.
//
// UploadWithContext is the same as Upload with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the context to add deadlining, timeouts, etc. The
// UploadWithContext may create sub-contexts for individual underlying requests.
//
// Additional functional options can be provided to configure the individual
// upload. These options are copies of the Uploader instance Upload is called from.
// Modifying the options will not impact the original Uploader instance.
//
// Use the WithUploaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this uploader.
//
// It is safe to call this method concurrently across goroutines.
func (u Uploader) UploadWithContext(ctx aws.Context, input *UploadInput, opts ...func(*Uploader)) (*UploadOutput, error) {
i := uploader{in: input, cfg: u, ctx: ctx}
for _, opt := range opts {
opt(&i.cfg)
}
i.cfg.RequestOptions = append(i.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager"))
return i.upload()
}
// UploadWithIterator will upload a batched amount of objects to S3. This operation uses
// the iterator pattern to know which object to upload next. Since this is an interface this
// allows for custom defined functionality.
//
// Example:
// svc:= s3manager.NewUploader(sess)
//
// objects := []BatchUploadObject{
// {
// Object: &s3manager.UploadInput {
// Key: aws.String("key"),
// Bucket: aws.String("bucket"),
// },
// },
// }
//
// iter := &s3manager.UploadObjectsIterator{Objects: objects}
// if err := svc.UploadWithIterator(aws.BackgroundContext(), iter); err != nil {
// return err
// }
func (u Uploader) UploadWithIterator(ctx aws.Context, iter BatchUploadIterator, opts ...func(*Uploader)) error {
var errs []Error
for iter.Next() {
object := iter.UploadObject()
if _, err := u.UploadWithContext(ctx, object.Object, opts...); err != nil {
s3Err := Error{
OrigErr: err,
Bucket: object.Object.Bucket,
Key: object.Object.Key,
}
errs = append(errs, s3Err)
}
if object.After == nil {
continue
}
if err := object.After(); err != nil {
s3Err := Error{
OrigErr: err,
Bucket: object.Object.Bucket,
Key: object.Object.Key,
}
errs = append(errs, s3Err)
}
}
if len(errs) > 0 {
return NewBatchError("BatchedUploadIncomplete", "some objects have failed to upload.", errs)
}
return nil
}
// internal structure to manage an upload to S3.
type uploader struct {
ctx aws.Context
cfg Uploader
in *UploadInput
readerPos int64 // current reader position
totalSize int64 // set to -1 if the size is not known
bufferPool sync.Pool
}
// internal logic for deciding whether to upload a single part or use a
// multipart upload.
func (u *uploader) upload() (*UploadOutput, error) {
u.init()
if u.cfg.PartSize < MinUploadPartSize {
msg := fmt.Sprintf("part size must be at least %d bytes", MinUploadPartSize)
return nil, awserr.New("ConfigError", msg, nil)
}
// Do one read to determine if we have more than one part
reader, _, part, err := u.nextReader()
if err == io.EOF { // single part
return u.singlePart(reader)
} else if err != nil {
return nil, awserr.New("ReadRequestBody", "read upload data failed", err)
}
mu := multiuploader{uploader: u}
return mu.upload(reader, part)
}
// init will initialize all default options.
func (u *uploader) init() {
if u.cfg.Concurrency == 0 {
u.cfg.Concurrency = DefaultUploadConcurrency
}
if u.cfg.PartSize == 0 {
u.cfg.PartSize = DefaultUploadPartSize
}
if u.cfg.MaxUploadParts == 0 {
u.cfg.MaxUploadParts = MaxUploadParts
}
u.bufferPool = sync.Pool{
New: func() interface{} { return make([]byte, u.cfg.PartSize) },
}
// Try to get the total size for some optimizations
u.initSize()
}
// initSize tries to detect the total stream size, setting u.totalSize. If
// the size is not known, totalSize is set to -1.
func (u *uploader) initSize() {
u.totalSize = -1
switch r := u.in.Body.(type) {
case io.Seeker:
n, err := aws.SeekerLen(r)
if err != nil {
return
}
u.totalSize = n
// Try to adjust partSize if it is too small and account for
// integer division truncation.
if u.totalSize/u.cfg.PartSize >= int64(u.cfg.MaxUploadParts) {
// Add one to the part size to account for remainders
// during the size calculation. e.g odd number of bytes.
u.cfg.PartSize = (u.totalSize / int64(u.cfg.MaxUploadParts)) + 1
}
}
}
// nextReader returns a seekable reader representing the next packet of data.
// This operation increases the shared u.readerPos counter, but note that it
// does not need to be wrapped in a mutex because nextReader is only called
// from the main thread.
func (u *uploader) nextReader() (io.ReadSeeker, int, []byte, error) {
type readerAtSeeker interface {
io.ReaderAt
io.ReadSeeker
}
switch r := u.in.Body.(type) {
case readerAtSeeker:
var err error
n := u.cfg.PartSize
if u.totalSize >= 0 {
bytesLeft := u.totalSize - u.readerPos
if bytesLeft <= u.cfg.PartSize {
err = io.EOF
n = bytesLeft
}
}
reader := io.NewSectionReader(r, u.readerPos, n)
u.readerPos += n
return reader, int(n), nil, err
default:
part := u.bufferPool.Get().([]byte)
n, err := readFillBuf(r, part)
u.readerPos += int64(n)
return bytes.NewReader(part[0:n]), n, part, err
}
}
func readFillBuf(r io.Reader, b []byte) (offset int, err error) {
for offset < len(b) && err == nil {
var n int
n, err = r.Read(b[offset:])
offset += n
}
return offset, err
}
// singlePart contains upload logic for uploading a single chunk via
// a regular PutObject request. Multipart requests require at least two
// parts, or at least 5MB of data.
func (u *uploader) singlePart(buf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.PutObjectInput{}
awsutil.Copy(params, u.in)
params.Body = buf
// Need to use request form because URL generated in request is
// used in return.
req, out := u.cfg.S3.PutObjectRequest(params)
req.SetContext(u.ctx)
req.ApplyOptions(u.cfg.RequestOptions...)
if err := req.Send(); err != nil {
return nil, err
}
url := req.HTTPRequest.URL.String()
return &UploadOutput{
Location: url,
VersionID: out.VersionId,
}, nil
}
// internal structure to manage a specific multipart upload to S3.
type multiuploader struct {
*uploader
wg sync.WaitGroup
m sync.Mutex
err error
uploadID string
parts completedParts
}
// keeps track of a single chunk of data being sent to S3.
type chunk struct {
buf io.ReadSeeker
part []byte
num int64
}
// completedParts is a wrapper to make parts sortable by their part number,
// since S3 required this list to be sent in sorted order.
type completedParts []*s3.CompletedPart
func (a completedParts) Len() int { return len(a) }
func (a completedParts) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
// upload will perform a multipart upload using the firstBuf buffer containing
// the first chunk of data.
func (u *multiuploader) upload(firstBuf io.ReadSeeker, firstPart []byte) (*UploadOutput, error) {
params := &s3.CreateMultipartUploadInput{}
awsutil.Copy(params, u.in)
// Create the multipart
resp, err := u.cfg.S3.CreateMultipartUploadWithContext(u.ctx, params, u.cfg.RequestOptions...)
if err != nil {
return nil, err
}
u.uploadID = *resp.UploadId
// Create the workers
ch := make(chan chunk, u.cfg.Concurrency)
for i := 0; i < u.cfg.Concurrency; i++ {
u.wg.Add(1)
go u.readChunk(ch)
}
// Send part 1 to the workers
var num int64 = 1
ch <- chunk{buf: firstBuf, part: firstPart, num: num}
// Read and queue the rest of the parts
for u.geterr() == nil && err == nil {
num++
// This upload exceeded maximum number of supported parts, error now.
if num > int64(u.cfg.MaxUploadParts) || num > int64(MaxUploadParts) {
var msg string
if num > int64(u.cfg.MaxUploadParts) {
msg = fmt.Sprintf("exceeded total allowed configured MaxUploadParts (%d). Adjust PartSize to fit in this limit",
u.cfg.MaxUploadParts)
} else {
msg = fmt.Sprintf("exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit",
MaxUploadParts)
}
u.seterr(awserr.New("TotalPartsExceeded", msg, nil))
break
}
var reader io.ReadSeeker
var nextChunkLen int
var part []byte
reader, nextChunkLen, part, err = u.nextReader()
if err != nil && err != io.EOF {
u.seterr(awserr.New(
"ReadRequestBody",
"read multipart upload data failed",
err))
break
}
if nextChunkLen == 0 {
// No need to upload empty part, if file was empty to start
// with empty single part would of been created and never
// started multipart upload.
break
}
ch <- chunk{buf: reader, part: part, num: num}
}
// Close the channel, wait for workers, and complete upload
close(ch)
u.wg.Wait()
complete := u.complete()
if err := u.geterr(); err != nil {
return nil, &multiUploadError{
awsError: awserr.New(
"MultipartUpload",
"upload multipart failed",
err),
uploadID: u.uploadID,
}
}
return &UploadOutput{
Location: aws.StringValue(complete.Location),
VersionID: complete.VersionId,
UploadID: u.uploadID,
}, nil
}
// readChunk runs in worker goroutines to pull chunks off of the ch channel
// and send() them as UploadPart requests.
func (u *multiuploader) readChunk(ch chan chunk) {
defer u.wg.Done()
for {
data, ok := <-ch
if !ok {
break
}
if u.geterr() == nil {
if err := u.send(data); err != nil {
u.seterr(err)
}
}
}
}
// send performs an UploadPart request and keeps track of the completed
// part information.
func (u *multiuploader) send(c chunk) error {
params := &s3.UploadPartInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
Body: c.buf,
UploadId: &u.uploadID,
SSECustomerAlgorithm: u.in.SSECustomerAlgorithm,
SSECustomerKey: u.in.SSECustomerKey,
PartNumber: &c.num,
}
resp, err := u.cfg.S3.UploadPartWithContext(u.ctx, params, u.cfg.RequestOptions...)
// put the byte array back into the pool to conserve memory
u.bufferPool.Put(c.part)
if err != nil {
return err
}
n := c.num
completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &n}
u.m.Lock()
u.parts = append(u.parts, completed)
u.m.Unlock()
return nil
}
// geterr is a thread-safe getter for the error object
func (u *multiuploader) geterr() error {
u.m.Lock()
defer u.m.Unlock()
return u.err
}
// seterr is a thread-safe setter for the error object
func (u *multiuploader) seterr(e error) {
u.m.Lock()
defer u.m.Unlock()
u.err = e
}
// fail will abort the multipart unless LeavePartsOnError is set to true.
func (u *multiuploader) fail() {
if u.cfg.LeavePartsOnError {
return
}
params := &s3.AbortMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadId: &u.uploadID,
}
_, err := u.cfg.S3.AbortMultipartUploadWithContext(u.ctx, params, u.cfg.RequestOptions...)
if err != nil {
logMessage(u.cfg.S3, aws.LogDebug, fmt.Sprintf("failed to abort multipart upload, %v", err))
}
}
// complete successfully completes a multipart upload and returns the response.
func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput {
if u.geterr() != nil {
u.fail()
return nil
}
// Parts must be sorted in PartNumber order.
sort.Sort(u.parts)
params := &s3.CompleteMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadId: &u.uploadID,
MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts},
}
resp, err := u.cfg.S3.CompleteMultipartUploadWithContext(u.ctx, params, u.cfg.RequestOptions...)
if err != nil {
u.seterr(err)
u.fail()
}
return resp
}

File diff suppressed because it is too large Load Diff

97
vendor/github.com/aws/aws-sdk-go/service/s3/service.go generated vendored Normal file
View File

@@ -0,0 +1,97 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
package s3
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/private/protocol/restxml"
)
// S3 provides the API operation methods for making requests to
// Amazon Simple Storage Service. See this package's package overview docs
// for details on the service.
//
// S3 methods are safe to use concurrently. It is not safe to
// modify mutate any of the struct's properties though.
type S3 struct {
*client.Client
}
// Used for custom client initialization logic
var initClient func(*client.Client)
// Used for custom request initialization logic
var initRequest func(*request.Request)
// Service information constants
const (
ServiceName = "s3" // Name of service.
EndpointsID = ServiceName // ID to lookup a service endpoint with.
ServiceID = "S3" // ServiceID is a unique identifer of a specific service.
)
// New creates a new instance of the S3 client with a session.
// If additional configuration is needed for the client instance use the optional
// aws.Config parameter to add your extra config.
//
// Example:
// // Create a S3 client from just a session.
// svc := s3.New(mySession)
//
// // Create a S3 client with additional configuration
// svc := s3.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *S3 {
c := p.ClientConfig(EndpointsID, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *S3 {
svc := &S3{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
ServiceID: ServiceID,
SigningName: signingName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2006-03-01",
},
handlers,
),
}
// Handlers
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(restxml.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(restxml.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(restxml.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(restxml.UnmarshalErrorHandler)
svc.Handlers.UnmarshalStream.PushBackNamed(restxml.UnmarshalHandler)
// Run custom client initialization if present
if initClient != nil {
initClient(svc.Client)
}
return svc
}
// newRequest creates a new request for a S3 operation and runs any
// custom request initialization.
func (c *S3) newRequest(op *request.Operation, params, data interface{}) *request.Request {
req := c.NewRequest(op, params, data)
// Run custom request initialization if present
if initRequest != nil {
initRequest(req)
}
return req
}

54
vendor/github.com/aws/aws-sdk-go/service/s3/sse.go generated vendored Normal file
View File

@@ -0,0 +1,54 @@
package s3
import (
"crypto/md5"
"encoding/base64"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
var errSSERequiresSSL = awserr.New("ConfigError", "cannot send SSE keys over HTTP.", nil)
func validateSSERequiresSSL(r *request.Request) {
if r.HTTPRequest.URL.Scheme == "https" {
return
}
if iface, ok := r.Params.(sseCustomerKeyGetter); ok {
if len(iface.getSSECustomerKey()) > 0 {
r.Error = errSSERequiresSSL
return
}
}
if iface, ok := r.Params.(copySourceSSECustomerKeyGetter); ok {
if len(iface.getCopySourceSSECustomerKey()) > 0 {
r.Error = errSSERequiresSSL
return
}
}
}
func computeSSEKeys(r *request.Request) {
headers := []string{
"x-amz-server-side-encryption-customer-key",
"x-amz-copy-source-server-side-encryption-customer-key",
}
for _, h := range headers {
md5h := h + "-md5"
if key := r.HTTPRequest.Header.Get(h); key != "" {
// Base64-encode the value
b64v := base64.StdEncoding.EncodeToString([]byte(key))
r.HTTPRequest.Header.Set(h, b64v)
// Add MD5 if it wasn't computed
if r.HTTPRequest.Header.Get(md5h) == "" {
sum := md5.Sum([]byte(key))
b64sum := base64.StdEncoding.EncodeToString(sum[:])
r.HTTPRequest.Header.Set(md5h, b64sum)
}
}
}
}

111
vendor/github.com/aws/aws-sdk-go/service/s3/sse_test.go generated vendored Normal file
View File

@@ -0,0 +1,111 @@
package s3_test
import (
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
func TestSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{DisableSSL: aws.Bool(true)})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
})
err := req.Build()
if err == nil {
t.Error("expected an error")
}
if e, a := "ConfigError", err.(awserr.Error).Code(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if !strings.Contains(err.(awserr.Error).Message(), "cannot send SSE keys over HTTP") {
t.Errorf("expected error to contain 'cannot send SSE keys over HTTP', but received %s", err.(awserr.Error).Message())
}
}
func TestCopySourceSSECustomerKeyOverHTTPError(t *testing.T) {
s := s3.New(unit.Session, &aws.Config{DisableSSL: aws.Bool(true)})
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
if err == nil {
t.Error("expected an error")
}
if e, a := "ConfigError", err.(awserr.Error).Code(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if !strings.Contains(err.(awserr.Error).Message(), "cannot send SSE keys over HTTP") {
t.Errorf("expected error to contain 'cannot send SSE keys over HTTP', but received %s", err.(awserr.Error).Message())
}
}
func TestComputeSSEKeys(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
})
err := req.Build()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "PG4LipwVIkqCKLmpjKFTHQ==", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func TestComputeSSEKeysShortcircuit(t *testing.T) {
s := s3.New(unit.Session)
req, _ := s.CopyObjectRequest(&s3.CopyObjectInput{
Bucket: aws.String("bucket"),
CopySource: aws.String("bucket/source"),
Key: aws.String("dest"),
SSECustomerKey: aws.String("key"),
CopySourceSSECustomerKey: aws.String("key"),
SSECustomerKeyMD5: aws.String("MD5"),
CopySourceSSECustomerKeyMD5: aws.String("MD5"),
})
err := req.Build()
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "a2V5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "MD5", req.HTTPRequest.Header.Get("x-amz-server-side-encryption-customer-key-md5"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "MD5", req.HTTPRequest.Header.Get("x-amz-copy-source-server-side-encryption-customer-key-md5"); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}

View File

@@ -0,0 +1,36 @@
package s3
import (
"bytes"
"io/ioutil"
"net/http"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkio"
)
func copyMultipartStatusOKUnmarhsalError(r *request.Request) {
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "unable to read response body", err)
return
}
body := bytes.NewReader(b)
r.HTTPResponse.Body = ioutil.NopCloser(body)
defer body.Seek(0, sdkio.SeekStart)
if body.Len() == 0 {
// If there is no body don't attempt to parse the body.
return
}
unmarshalError(r)
if err, ok := r.Error.(awserr.Error); ok && err != nil {
if err.Code() == "SerializationError" {
r.Error = nil
return
}
r.HTTPResponse.StatusCode = http.StatusServiceUnavailable
}
}

View File

@@ -0,0 +1,166 @@
package s3_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
const errMsg = `<Error><Code>ErrorCode</Code><Message>message body</Message><RequestId>requestID</RequestId><HostId>hostID=</HostId></Error>`
var lastModifiedTime = time.Date(2009, 11, 23, 0, 0, 0, 0, time.UTC)
func TestCopyObjectNoError(t *testing.T) {
const successMsg = `
<?xml version="1.0" encoding="UTF-8"?>
<CopyObjectResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><LastModified>2009-11-23T0:00:00Z</LastModified><ETag>&quot;1da64c7f13d1e8dbeaea40b905fd586c&quot;</ETag></CopyObjectResult>`
res, err := newCopyTestSvc(successMsg).CopyObject(&s3.CopyObjectInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/exists.txt"),
Key: aws.String("destination.txt"),
})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyObjectResult.ETag; e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := lastModifiedTime, *res.CopyObjectResult.LastModified; !e.Equal(a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestCopyObjectError(t *testing.T) {
_, err := newCopyTestSvc(errMsg).CopyObject(&s3.CopyObjectInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/doesnotexist.txt"),
Key: aws.String("destination.txt"),
})
if err == nil {
t.Error("expected error, but received none")
}
e := err.(awserr.Error)
if e, a := "ErrorCode", e.Code(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "message body", e.Message(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func TestUploadPartCopySuccess(t *testing.T) {
const successMsg = `
<?xml version="1.0" encoding="UTF-8"?>
<UploadPartCopyResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><LastModified>2009-11-23T0:00:00Z</LastModified><ETag>&quot;1da64c7f13d1e8dbeaea40b905fd586c&quot;</ETag></UploadPartCopyResult>`
res, err := newCopyTestSvc(successMsg).UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/doesnotexist.txt"),
Key: aws.String("destination.txt"),
PartNumber: aws.Int64(0),
UploadId: aws.String("uploadID"),
})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := fmt.Sprintf(`%q`, "1da64c7f13d1e8dbeaea40b905fd586c"), *res.CopyPartResult.ETag; e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := lastModifiedTime, *res.CopyPartResult.LastModified; !e.Equal(a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestUploadPartCopyError(t *testing.T) {
_, err := newCopyTestSvc(errMsg).UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: aws.String("bucketname"),
CopySource: aws.String("bucketname/doesnotexist.txt"),
Key: aws.String("destination.txt"),
PartNumber: aws.Int64(0),
UploadId: aws.String("uploadID"),
})
if err == nil {
t.Error("expected an error")
}
e := err.(awserr.Error)
if e, a := "ErrorCode", e.Code(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "message body", e.Message(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func TestCompleteMultipartUploadSuccess(t *testing.T) {
const successMsg = `
<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Location>locationName</Location><Bucket>bucketName</Bucket><Key>keyName</Key><ETag>"etagVal"</ETag></CompleteMultipartUploadResult>`
res, err := newCopyTestSvc(successMsg).CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucketname"),
Key: aws.String("key"),
UploadId: aws.String("uploadID"),
})
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if e, a := `"etagVal"`, *res.ETag; e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "bucketName", *res.Bucket; e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "keyName", *res.Key; e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "locationName", *res.Location; e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func TestCompleteMultipartUploadError(t *testing.T) {
_, err := newCopyTestSvc(errMsg).CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucketname"),
Key: aws.String("key"),
UploadId: aws.String("uploadID"),
})
if err == nil {
t.Error("expected an error")
}
e := err.(awserr.Error)
if e, a := "ErrorCode", e.Code(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
if e, a := "message body", e.Message(); e != a {
t.Errorf("expected %s, but received %s", e, a)
}
}
func newCopyTestSvc(errMsg string) *s3.S3 {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, errMsg, http.StatusOK)
}))
return s3.New(unit.Session, aws.NewConfig().
WithEndpoint(server.URL).
WithDisableSSL(true).
WithMaxRetries(0).
WithS3ForcePathStyle(true))
}

View File

@@ -0,0 +1,178 @@
[
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3.amazonaws.com",
"Region": "us-east-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3.us-west-1.amazonaws.com",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-with-number-1",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-with-number-1.s3.us-west-1.amazonaws.com",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3.cn-north-1.amazonaws.com.cn",
"Region": "cn-north-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "BucketName",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.amazonaws.com/BucketName",
"Region": "us-east-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "BucketName",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.amazonaws.com/BucketName",
"Region": "us-east-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket_name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.us-west-1.amazonaws.com/bucket_name",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket.name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.us-west-1.amazonaws.com/bucket.name",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "-bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.us-west-1.amazonaws.com/-bucket-name",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name-",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.us-west-1.amazonaws.com/bucket-name-",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "aa",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.us-west-1.amazonaws.com/aa",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.us-west-1.amazonaws.com/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3-accelerate.amazonaws.com",
"Region": "us-east-1",
"UseDualstack": false,
"UseS3Accelerate": true
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3-accelerate.amazonaws.com",
"Region": "us-west-1",
"UseDualstack": false,
"UseS3Accelerate": true
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3.dualstack.us-east-1.amazonaws.com",
"Region": "us-east-1",
"UseDualstack": true,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3.dualstack.us-west-2.amazonaws.com",
"Region": "us-west-2",
"UseDualstack": true,
"UseS3Accelerate": false
},
{
"Bucket": "bucket.name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://s3.dualstack.us-west-2.amazonaws.com/bucket.name",
"Region": "us-west-2",
"UseDualstack": true,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "default",
"ExpectedUri": "https://bucket-name.s3-accelerate.dualstack.amazonaws.com",
"Region": "us-east-1",
"UseDualstack": true,
"UseS3Accelerate": true
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "path",
"ExpectedUri": "https://s3.amazonaws.com/bucket-name",
"Region": "us-east-1",
"UseDualstack": false,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "path",
"ExpectedUri": "https://bucket-name.s3-accelerate.amazonaws.com",
"Region": "us-east-1",
"UseDualstack": false,
"UseS3Accelerate": true
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "path",
"ExpectedUri": "https://s3.dualstack.us-east-1.amazonaws.com/bucket-name",
"Region": "us-east-1",
"UseDualstack": true,
"UseS3Accelerate": false
},
{
"Bucket": "bucket-name",
"ConfiguredAddressingStyle": "path",
"ExpectedUri": "https://bucket-name.s3-accelerate.dualstack.amazonaws.com",
"Region": "us-east-1",
"UseDualstack": true,
"UseS3Accelerate": true
}
]

View File

@@ -0,0 +1,103 @@
package s3
import (
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
type xmlErrorResponse struct {
XMLName xml.Name `xml:"Error"`
Code string `xml:"Code"`
Message string `xml:"Message"`
}
func unmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
defer io.Copy(ioutil.Discard, r.HTTPResponse.Body)
hostID := r.HTTPResponse.Header.Get("X-Amz-Id-2")
// Bucket exists in a different region, and request needs
// to be made to the correct region.
if r.HTTPResponse.StatusCode == http.StatusMovedPermanently {
r.Error = requestFailure{
RequestFailure: awserr.NewRequestFailure(
awserr.New("BucketRegionError",
fmt.Sprintf("incorrect region, the bucket is not in '%s' region",
aws.StringValue(r.Config.Region)),
nil),
r.HTTPResponse.StatusCode,
r.RequestID,
),
hostID: hostID,
}
return
}
var errCode, errMsg string
// Attempt to parse error from body if it is known
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
errCode = "SerializationError"
errMsg = "failed to decode S3 XML error response"
} else {
errCode = resp.Code
errMsg = resp.Message
err = nil
}
// Fallback to status code converted to message if still no error code
if len(errCode) == 0 {
statusText := http.StatusText(r.HTTPResponse.StatusCode)
errCode = strings.Replace(statusText, " ", "", -1)
errMsg = statusText
}
r.Error = requestFailure{
RequestFailure: awserr.NewRequestFailure(
awserr.New(errCode, errMsg, err),
r.HTTPResponse.StatusCode,
r.RequestID,
),
hostID: hostID,
}
}
// A RequestFailure provides access to the S3 Request ID and Host ID values
// returned from API operation errors. Getting the error as a string will
// return the formated error with the same information as awserr.RequestFailure,
// while also adding the HostID value from the response.
type RequestFailure interface {
awserr.RequestFailure
// Host ID is the S3 Host ID needed for debug, and contacting support
HostID() string
}
type requestFailure struct {
awserr.RequestFailure
hostID string
}
func (r requestFailure) Error() string {
extra := fmt.Sprintf("status code: %d, request id: %s, host id: %s",
r.StatusCode(), r.RequestID(), r.hostID)
return awserr.SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}
func (r requestFailure) String() string {
return r.Error()
}
func (r requestFailure) HostID() string {
return r.hostID
}

View File

@@ -0,0 +1,40 @@
package s3
import (
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestUnmarhsalErrorLeak(t *testing.T) {
req := &request.Request{
HTTPRequest: &http.Request{
Header: make(http.Header),
Body: &awstesting.ReadCloser{Size: 2048},
},
}
req.HTTPResponse = &http.Response{
Body: &awstesting.ReadCloser{Size: 2048},
Header: http.Header{
"X-Amzn-Requestid": []string{"1"},
},
StatusCode: http.StatusOK,
}
reader := req.HTTPResponse.Body.(*awstesting.ReadCloser)
unmarshalError(req)
if req.Error == nil {
t.Error("expected an error, but received none")
}
if !reader.Closed {
t.Error("expected reader to be closed")
}
if e, a := 0, reader.Size; e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}

View File

@@ -0,0 +1,253 @@
package s3_test
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
)
type testErrorCase struct {
RespFn func() *http.Response
ReqID, HostID string
Code, Msg string
WithoutStatusMsg bool
}
var testUnmarshalCases = []testErrorCase{
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 301,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: -1,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "BucketRegionError", Msg: "incorrect region, the bucket is not in 'mock-region' region",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 403,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: 0,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "Forbidden", Msg: "Forbidden",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 400,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: 0,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "BadRequest", Msg: "Bad Request",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 404,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: 0,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "NotFound", Msg: "Not Found",
},
{
// SDK only reads request ID and host ID from the header. The values
// in message body are ignored.
RespFn: func() *http.Response {
body := `<Error><Code>SomeException</Code><Message>Exception message</Message><RequestId>ignored-request-id</RequestId><HostId>ignored-host-id</HostId></Error>`
return &http.Response{
StatusCode: 500,
Header: http.Header{
"X-Amz-Request-Id": []string{"taken-request-id"},
"X-Amz-Id-2": []string{"taken-host-id"},
},
Body: ioutil.NopCloser(strings.NewReader(body)),
ContentLength: int64(len(body)),
}
},
ReqID: "taken-request-id",
HostID: "taken-host-id",
Code: "SomeException", Msg: "Exception message",
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 404,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: -1,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "NotFound", Msg: "Not Found", WithoutStatusMsg: true,
},
{
RespFn: func() *http.Response {
return &http.Response{
StatusCode: 404,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(bytes.NewReader(nil)),
ContentLength: -1,
}
},
ReqID: "abc123",
HostID: "321cba",
Code: "NotFound", Msg: "Not Found",
},
}
func TestUnmarshalError(t *testing.T) {
for i, c := range testUnmarshalCases {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = c.RespFn()
if !c.WithoutStatusMsg {
r.HTTPResponse.Status = fmt.Sprintf("%d%s",
r.HTTPResponse.StatusCode,
http.StatusText(r.HTTPResponse.StatusCode))
}
})
_, err := s.PutBucketAcl(&s3.PutBucketAclInput{
Bucket: aws.String("bucket"), ACL: aws.String("public-read"),
})
if err == nil {
t.Fatalf("%d, expected error, got nil", i)
}
if e, a := c.Code, err.(awserr.Error).Code(); e != a {
t.Errorf("%d, Code: expect %s, got %s", i, e, a)
}
if e, a := c.Msg, err.(awserr.Error).Message(); e != a {
t.Errorf("%d, Message: expect %s, got %s", i, e, a)
}
if e, a := c.ReqID, err.(awserr.RequestFailure).RequestID(); e != a {
t.Errorf("%d, RequestID: expect %s, got %s", i, e, a)
}
if e, a := c.HostID, err.(s3.RequestFailure).HostID(); e != a {
t.Errorf("%d, HostID: expect %s, got %s", i, e, a)
}
}
}
const completeMultiResp = `
163
<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Location>https://bucket.s3-us-west-2.amazonaws.com/key</Location><Bucket>bucket</Bucket><Key>key</Key><ETag>&quot;a7d414b9133d6483d9a1c4e04e856e3b-2&quot;</ETag></CompleteMultipartUploadResult>
0
`
func Test200NoErrorUnmarshalError(t *testing.T) {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(strings.NewReader(completeMultiResp)),
ContentLength: -1,
}
r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode)
})
_, err := s.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucket"), Key: aws.String("key"),
UploadId: aws.String("id"),
MultipartUpload: &s3.CompletedMultipartUpload{Parts: []*s3.CompletedPart{
{ETag: aws.String("etag"), PartNumber: aws.Int64(1)},
}},
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}
const completeMultiErrResp = `<Error><Code>SomeException</Code><Message>Exception message</Message></Error>`
func Test200WithErrorUnmarshalError(t *testing.T) {
s := s3.New(unit.Session)
s.Handlers.Send.Clear()
s.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{
"X-Amz-Request-Id": []string{"abc123"},
"X-Amz-Id-2": []string{"321cba"},
},
Body: ioutil.NopCloser(strings.NewReader(completeMultiErrResp)),
ContentLength: -1,
}
r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode)
})
_, err := s.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String("bucket"), Key: aws.String("key"),
UploadId: aws.String("id"),
MultipartUpload: &s3.CompletedMultipartUpload{Parts: []*s3.CompletedPart{
{ETag: aws.String("etag"), PartNumber: aws.Int64(1)},
}},
})
if err == nil {
t.Fatalf("expected error, got nil")
}
if e, a := "SomeException", err.(awserr.Error).Code(); e != a {
t.Errorf("Code: expect %s, got %s", e, a)
}
if e, a := "Exception message", err.(awserr.Error).Message(); e != a {
t.Errorf("Message: expect %s, got %s", e, a)
}
if e, a := "abc123", err.(s3.RequestFailure).RequestID(); e != a {
t.Errorf("RequestID: expect %s, got %s", e, a)
}
if e, a := "321cba", err.(s3.RequestFailure).HostID(); e != a {
t.Errorf("HostID: expect %s, got %s", e, a)
}
}

214
vendor/github.com/aws/aws-sdk-go/service/s3/waiters.go generated vendored Normal file
View File

@@ -0,0 +1,214 @@
// Code generated by private/model/cli/gen-api/main.go. DO NOT EDIT.
package s3
import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// WaitUntilBucketExists uses the Amazon S3 API operation
// HeadBucket to wait for a condition to be met before returning.
// If the condition is not met within the max attempt window, an error will
// be returned.
func (c *S3) WaitUntilBucketExists(input *HeadBucketInput) error {
return c.WaitUntilBucketExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilBucketExistsWithContext is an extended version of WaitUntilBucketExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilBucketExistsWithContext(ctx aws.Context, input *HeadBucketInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilBucketExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 301,
},
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 403,
},
{
State: request.RetryWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
var inCpy *HeadBucketInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := c.HeadBucketRequest(inCpy)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}
// WaitUntilBucketNotExists uses the Amazon S3 API operation
// HeadBucket to wait for a condition to be met before returning.
// If the condition is not met within the max attempt window, an error will
// be returned.
func (c *S3) WaitUntilBucketNotExists(input *HeadBucketInput) error {
return c.WaitUntilBucketNotExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilBucketNotExistsWithContext is an extended version of WaitUntilBucketNotExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilBucketNotExistsWithContext(ctx aws.Context, input *HeadBucketInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilBucketNotExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
var inCpy *HeadBucketInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := c.HeadBucketRequest(inCpy)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}
// WaitUntilObjectExists uses the Amazon S3 API operation
// HeadObject to wait for a condition to be met before returning.
// If the condition is not met within the max attempt window, an error will
// be returned.
func (c *S3) WaitUntilObjectExists(input *HeadObjectInput) error {
return c.WaitUntilObjectExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilObjectExistsWithContext is an extended version of WaitUntilObjectExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilObjectExistsWithContext(ctx aws.Context, input *HeadObjectInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilObjectExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 200,
},
{
State: request.RetryWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
var inCpy *HeadObjectInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := c.HeadObjectRequest(inCpy)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}
// WaitUntilObjectNotExists uses the Amazon S3 API operation
// HeadObject to wait for a condition to be met before returning.
// If the condition is not met within the max attempt window, an error will
// be returned.
func (c *S3) WaitUntilObjectNotExists(input *HeadObjectInput) error {
return c.WaitUntilObjectNotExistsWithContext(aws.BackgroundContext(), input)
}
// WaitUntilObjectNotExistsWithContext is an extended version of WaitUntilObjectNotExists.
// With the support for passing in a context and options to configure the
// Waiter and the underlying request options.
//
// The context must be non-nil and will be used for request cancellation. If
// the context is nil a panic will occur. In the future the SDK may create
// sub-contexts for http.Requests. See https://golang.org/pkg/context/
// for more information on using Contexts.
func (c *S3) WaitUntilObjectNotExistsWithContext(ctx aws.Context, input *HeadObjectInput, opts ...request.WaiterOption) error {
w := request.Waiter{
Name: "WaitUntilObjectNotExists",
MaxAttempts: 20,
Delay: request.ConstantWaiterDelay(5 * time.Second),
Acceptors: []request.WaiterAcceptor{
{
State: request.SuccessWaiterState,
Matcher: request.StatusWaiterMatch,
Expected: 404,
},
},
Logger: c.Config.Logger,
NewRequest: func(opts []request.Option) (*request.Request, error) {
var inCpy *HeadObjectInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := c.HeadObjectRequest(inCpy)
req.SetContext(ctx)
req.ApplyOptions(opts...)
return req, nil
},
}
w.ApplyOptions(opts...)
return w.WaitWithContext(ctx)
}