Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve the snapshot ID when testing for required auth on MD #2542

Merged
merged 13 commits into from
Mar 13, 2024
2 changes: 2 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ jobs:
AZCOPY_E2E_CLASSIC_ACCOUNT_KEY: $(AZCOPY_E2E_CLASSIC_ACCOUNT_KEY)
AZCOPY_E2E_LOG_OUTPUT: '$(System.DefaultWorkingDirectory)/logs'
AZCOPY_E2E_OAUTH_MANAGED_DISK_CONFIG: $(AZCOPY_E2E_OAUTH_MANAGED_DISK_CONFIG)
AZCOPY_E2E_OAUTH_MANAGED_DISK_SNAPSHOT_CONFIG: $(AZCOPY_E2E_OAUTH_MANAGED_DISK_SNAPSHOT_CONFIG)
AZCOPY_E2E_STD_MANAGED_DISK_CONFIG: $(AZCOPY_E2E_STD_MANAGED_DISK_CONFIG)
AZCOPY_E2E_STD_MANAGED_DISK_SNAPSHOT_CONFIG: $(AZCOPY_E2E_STD_MANAGED_DISK_SNAPSHOT_CONFIG)
CPK_ENCRYPTION_KEY: $(CPK_ENCRYPTION_KEY)
CPK_ENCRYPTION_KEY_SHA256: $(CPK_ENCRYPTION_KEY_SHA256)
AZCOPY_E2E_EXECUTABLE_PATH: $(System.DefaultWorkingDirectory)/$(build_name)
Expand Down
20 changes: 8 additions & 12 deletions cmd/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,7 @@ func (cca *CookedCopyCmdArgs) processRedirectionDownload(blobResource common.Res
// The isPublic flag is useful in S2S transfers but doesn't much matter for download. Fortunately, no S2S happens here.
// This means that if there's auth, there's auth. We're happy and can move on.
// GetCredentialInfoForLocation also populates oauth token fields... so, it's very easy.
credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource.Value, blobResource.SAS, true, cca.CpkOptions)
credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource, true, cca.CpkOptions)

if err != nil {
return fmt.Errorf("fatal: cannot find auth on source blob URL: %s", err.Error())
Expand Down Expand Up @@ -1353,7 +1353,7 @@ func (cca *CookedCopyCmdArgs) processRedirectionUpload(blobResource common.Resou
}

// GetCredentialInfoForLocation populates oauth token fields... so, it's very easy.
credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource.Value, blobResource.SAS, false, cca.CpkOptions)
credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource, false, cca.CpkOptions)

if err != nil {
return fmt.Errorf("fatal: cannot find auth on destination blob URL: %s", err.Error())
Expand Down Expand Up @@ -1425,7 +1425,7 @@ func (cca *CookedCopyCmdArgs) getSrcCredential(ctx context.Context, jpo *common.
panic("Invalid Source")
}

srcCredInfo, isPublic, err := GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source.Value, cca.Source.SAS, true, cca.CpkOptions)
srcCredInfo, isPublic, err := GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source, true, cca.CpkOptions)
if err != nil {
return srcCredInfo, err
// If S2S and source takes OAuthToken as its cred type (OR) source takes anonymous as its cred type, but it's not public and there's no SAS
Expand Down Expand Up @@ -1483,11 +1483,9 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
// For upload&download, only one side need credential.
// For S2S copy, as azcopy-v10 use Put*FromUrl, only one credential is needed for destination.
if cca.credentialInfo.CredentialType, err = getCredentialType(ctx, rawFromToInfo{
fromTo: cca.FromTo,
source: cca.Source.Value,
destination: cca.Destination.Value,
sourceSAS: cca.Source.SAS,
destinationSAS: cca.Destination.SAS,
fromTo: cca.FromTo,
source: cca.Source,
destination: cca.Destination,
}, cca.CpkOptions); err != nil {
return err
}
Expand Down Expand Up @@ -1556,10 +1554,9 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
if err != nil {
return err
}
sourceURL, _ := cca.Source.String()
jobPartOrder.SrcServiceClient, err = common.GetServiceClientForLocation(
cca.FromTo.From(),
sourceURL,
cca.Source,
srcCredInfo.CredentialType,
srcCredInfo.OAuthTokenInfo.TokenCredential,
&options,
Expand All @@ -1575,7 +1572,6 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
AllowSourceTrailingDot: cca.trailingDot == common.ETrailingDotOption.Enable() && cca.FromTo.From() == common.ELocation.File(),
}
}
dstURL, _ := cca.Destination.String()

var srcCred *common.ScopedCredential
if cca.FromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() {
Expand All @@ -1584,7 +1580,7 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred)
jobPartOrder.DstServiceClient, err = common.GetServiceClientForLocation(
cca.FromTo.To(),
dstURL,
cca.Destination,
cca.credentialInfo.CredentialType,
cca.credentialInfo.OAuthTokenInfo.TokenCredential,
&options,
Expand Down
11 changes: 3 additions & 8 deletions cmd/copyEnumeratorInit.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func (cca *CookedCopyCmdArgs) isDestDirectory(dst common.ResourceString, ctx *co
return false
}

if dstCredInfo, _, err = GetCredentialInfoForLocation(*ctx, cca.FromTo.To(), cca.Destination.Value, cca.Destination.SAS, false, cca.CpkOptions); err != nil {
if dstCredInfo, _, err = GetCredentialInfoForLocation(*ctx, cca.FromTo.To(), cca.Destination, false, cca.CpkOptions); err != nil {
return false
}

Expand Down Expand Up @@ -436,23 +436,18 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA
existingContainers[containerName] = true

var dstCredInfo common.CredentialInfo
dstURL, err := dstWithSAS.String()
if err != nil {
return err
}

// 3minutes is enough time to list properties of a container, and create new if it does not exist.
ctx, cancel := context.WithTimeout(parentCtx, time.Minute*3)
defer cancel()
if dstCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.To(), cca.Destination.Value, cca.Destination.SAS, false, cca.CpkOptions); err != nil {
if dstCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.To(), cca.Destination, false, cca.CpkOptions); err != nil {
return err
}

options := createClientOptions(common.AzcopyCurrentJobLogger, nil)

sc, err := common.GetServiceClientForLocation(
cca.FromTo.To(),
dstURL,
dstWithSAS,
dstCredInfo.CredentialType,
dstCredInfo.OAuthTokenInfo.TokenCredential,
&options,
Expand Down
35 changes: 17 additions & 18 deletions cmd/credentialUtil.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,8 @@ func GetCredTypeFromEnvVar() common.CredentialType {
}

type rawFromToInfo struct {
fromTo common.FromTo
source, destination string
sourceSAS, destinationSAS string // Standalone SAS which might be provided
fromTo common.FromTo
source, destination common.ResourceString
}

const trustedSuffixesNameAAD = "trusted-microsoft-suffixes"
Expand Down Expand Up @@ -369,7 +368,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk
return false
}

// This request will not be logged. This can fail, and too many Cx do not like this.
// This request will not be logged. This can fail, and too many Cx do not like this.
clientOptions := ste.NewClientOptions(policy.RetryOptions{
MaxRetries: ste.UploadMaxTries,
TryTimeout: ste.UploadTryTimeout,
Expand Down Expand Up @@ -402,7 +401,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk

// mdAccountNeedsOAuth pings the passed in md account, and checks if we need additional token with Disk-socpe
func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions common.CpkOptions) bool {
// This request will not be logged. This can fail, and too many Cx do not like this.
// This request will not be logged. This can fail, and too many Cx do not like this.
clientOptions := ste.NewClientOptions(policy.RetryOptions{
MaxRetries: ste.UploadMaxTries,
TryTimeout: ste.UploadTryTimeout,
Expand Down Expand Up @@ -430,11 +429,11 @@ func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions
return false
}

func getCredentialTypeForLocation(ctx context.Context, location common.Location, resource, resourceSAS string, isSource bool, cpkOptions common.CpkOptions) (credType common.CredentialType, isPublic bool, err error) {
return doGetCredentialTypeForLocation(ctx, location, resource, resourceSAS, isSource, GetCredTypeFromEnvVar, cpkOptions)
func getCredentialTypeForLocation(ctx context.Context, location common.Location, resource common.ResourceString, isSource bool, cpkOptions common.CpkOptions) (credType common.CredentialType, isPublic bool, err error) {
return doGetCredentialTypeForLocation(ctx, location, resource, isSource, GetCredTypeFromEnvVar, cpkOptions)
}

func doGetCredentialTypeForLocation(ctx context.Context, location common.Location, resource, resourceSAS string, isSource bool, getForcedCredType func() common.CredentialType, cpkOptions common.CpkOptions) (credType common.CredentialType, public bool, err error) {
func doGetCredentialTypeForLocation(ctx context.Context, location common.Location, resource common.ResourceString, isSource bool, getForcedCredType func() common.CredentialType, cpkOptions common.CpkOptions) (credType common.CredentialType, public bool, err error) {
public = false
err = nil

Expand All @@ -453,7 +452,7 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio
return
}

if err = checkAuthSafeForTarget(credType, resource, cmdLineExtraSuffixesAAD, location); err != nil {
if err = checkAuthSafeForTarget(credType, resource.Value, cmdLineExtraSuffixesAAD, location); err != nil {
credType = common.ECredentialType.Unknown()
public = false
}
Expand Down Expand Up @@ -489,14 +488,14 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio

// Special blob destinations - public and MD account needing oAuth
if location == common.ELocation.Blob() {
if isSource && resourceSAS == "" && isPublic(ctx, resource, cpkOptions) {
uri, _ := resource.FullURL()
if isSource && resource.SAS == "" && isPublic(ctx, uri.String(), cpkOptions) {
credType = common.ECredentialType.Anonymous()
public = true
return
}

uri, _ := url.Parse(resource)
if strings.HasPrefix(uri.Host, "md-") && mdAccountNeedsOAuth(ctx, resource, cpkOptions) {
if strings.HasPrefix(uri.Host, "md-") && mdAccountNeedsOAuth(ctx, uri.String(), cpkOptions) {
if !oAuthTokenExists() {
return common.ECredentialType.Unknown(), false,
common.NewAzError(common.EAzError.LoginCredMissing(), "No SAS token or OAuth token is present and the resource is not public")
Expand All @@ -507,7 +506,7 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio
}
}

if resourceSAS != "" {
if resource.SAS != "" {
credType = common.ECredentialType.Anonymous()
return
}
Expand Down Expand Up @@ -535,10 +534,10 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio
return
}

func GetCredentialInfoForLocation(ctx context.Context, location common.Location, resource, resourceSAS string, isSource bool, cpkOptions common.CpkOptions) (credInfo common.CredentialInfo, isPublic bool, err error) {
func GetCredentialInfoForLocation(ctx context.Context, location common.Location, resource common.ResourceString, isSource bool, cpkOptions common.CpkOptions) (credInfo common.CredentialInfo, isPublic bool, err error) {

// get the type
credInfo.CredentialType, isPublic, err = getCredentialTypeForLocation(ctx, location, resource, resourceSAS, isSource, cpkOptions)
credInfo.CredentialType, isPublic, err = getCredentialTypeForLocation(ctx, location, resource, isSource, cpkOptions)

// flesh out the rest of the fields, for those types that require it
if credInfo.CredentialType.IsAzureOAuth() {
Expand All @@ -563,17 +562,17 @@ func getCredentialType(ctx context.Context, raw rawFromToInfo, cpkOptions common
switch {
case raw.fromTo.To().IsRemote():
// we authenticate to the destination. Source is assumed to be SAS, or public, or a local resource
credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.To(), raw.destination, raw.destinationSAS, false, common.CpkOptions{})
credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.To(), raw.destination, false, common.CpkOptions{})
case raw.fromTo == common.EFromTo.BlobTrash() ||
raw.fromTo == common.EFromTo.BlobFSTrash() ||
raw.fromTo == common.EFromTo.FileTrash():
// For to Trash direction, use source as resource URL
// Also, by setting isSource=false we inform getCredentialTypeForLocation() that resource
// being deleted cannot be public.
credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, raw.sourceSAS, false, cpkOptions)
credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, false, cpkOptions)
case raw.fromTo.From().IsRemote() && raw.fromTo.To().IsLocal():
// we authenticate to the source.
credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, raw.sourceSAS, true, cpkOptions)
credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, true, cpkOptions)
default:
credType = common.ECredentialType.Anonymous()
// Log the FromTo types which getCredentialType hasn't solved, in case of miss-use.
Expand Down
31 changes: 18 additions & 13 deletions cmd/jobsResume.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ type resumeCmdArgs struct {
func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
ctx context.Context,
fromTo common.FromTo,
source string,
destination string,
source common.ResourceString,
destination common.ResourceString,
) (*common.ServiceClient, *common.ServiceClient, error) {
if len(rca.SourceSAS) > 0 && rca.SourceSAS[0] != '?' {
rca.SourceSAS = "?" + rca.SourceSAS
Expand All @@ -258,10 +258,12 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
rca.DestinationSAS = "?" + rca.DestinationSAS
}

source.SAS = rca.SourceSAS
destination.SAS = rca.DestinationSAS

srcCredType, _, err := getCredentialTypeForLocation(ctx,
fromTo.From(),
source,
rca.SourceSAS,
true,
common.CpkOptions{})
if err != nil {
Expand All @@ -271,7 +273,6 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
dstCredType, _, err := getCredentialTypeForLocation(ctx,
fromTo.To(),
destination,
rca.DestinationSAS,
false,
common.CpkOptions{})
if err != nil {
Expand All @@ -295,7 +296,7 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(

options := createClientOptions(common.AzcopyCurrentJobLogger, nil)

srcServiceClient, err := common.GetServiceClientForLocation(fromTo.From(), source+rca.SourceSAS, srcCredType, tc, &options, nil)
srcServiceClient, err := common.GetServiceClientForLocation(fromTo.From(), source, srcCredType, tc, &options, nil)
if err != nil {
return nil, nil, err
}
Expand All @@ -305,7 +306,7 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients(
srcCred = common.NewScopedCredential(tc, srcCredType)
}
options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred)
dstServiceClient, err := common.GetServiceClientForLocation(fromTo.To(), destination+rca.DestinationSAS, dstCredType, tc, &options, nil)
dstServiceClient, err := common.GetServiceClientForLocation(fromTo.To(), destination, dstCredType, tc, &options, nil)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -375,23 +376,27 @@ func (rca resumeCmdArgs) process() error {
// Initialize credential info.
credentialInfo := common.CredentialInfo{}
// TODO: Replace context with root context
srcResourceString, err := SplitResourceString(getJobFromToResponse.Source, getJobFromToResponse.FromTo.From())
_ = err // todo
srcResourceString.SAS = rca.SourceSAS
dstResourceString, err := SplitResourceString(getJobFromToResponse.Destination, getJobFromToResponse.FromTo.To())
_ = err // todo
dstResourceString.SAS = rca.DestinationSAS

// we should stop using credentiaLInfo and use the clients instead. But before we fix
// that there will be repeated calls to get Credential type for correctness.
if credentialInfo.CredentialType, err = getCredentialType(ctx, rawFromToInfo{
fromTo: getJobFromToResponse.FromTo,
source: getJobFromToResponse.Source,
destination: getJobFromToResponse.Destination,
sourceSAS: rca.SourceSAS,
destinationSAS: rca.DestinationSAS,
fromTo: getJobFromToResponse.FromTo,
source: srcResourceString,
destination: dstResourceString,
}, common.CpkOptions{}); err != nil {
return err
}

srcServiceClient, dstServiceClient, err := rca.getSourceAndDestinationServiceClients(
ctx, getJobFromToResponse.FromTo,
getJobFromToResponse.Source,
getJobFromToResponse.Destination,
srcResourceString,
dstResourceString,
)
if err != nil {
return errors.New("could not create service clients " + err.Error())
Expand Down
2 changes: 1 addition & 1 deletion cmd/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (cooked cookedListCmdArgs) HandleListContainerCommand() (err error) {
}

// isSource is rather misnomer for canBePublic. We can list public containers, and hence isSource=true
if credentialInfo, _, err = GetCredentialInfoForLocation(ctx, cooked.location, source.Value, source.SAS, true, common.CpkOptions{}); err != nil {
if credentialInfo, _, err = GetCredentialInfoForLocation(ctx, cooked.location, source, true, common.CpkOptions{}); err != nil {
return fmt.Errorf("failed to obtain credential info: %s", err.Error())
} else if cooked.location == cooked.location.File() && source.SAS == "" {
return errors.New("azure files requires a SAS token for authentication")
Expand Down
2 changes: 1 addition & 1 deletion cmd/make.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (cookedArgs cookedMakeCmdArgs) process() (err error) {
return fmt.Errorf("failed to resolve target: %w", err)
}

credentialInfo, _, err := GetCredentialInfoForLocation(ctx, cookedArgs.resourceLocation, resourceStringParts.Value, resourceStringParts.SAS, false, common.CpkOptions{})
credentialInfo, _, err := GetCredentialInfoForLocation(ctx, cookedArgs.resourceLocation, resourceStringParts, false, common.CpkOptions{})
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions cmd/removeEnumerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er
jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo)
}

targetURL, _ := cca.Source.String()
from := cca.FromTo.From()
if !from.SupportsTrailingDot() {
cca.trailingDot = common.ETrailingDotOption.Disable()
Expand All @@ -97,7 +96,7 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er
}
targetServiceClient, err := common.GetServiceClientForLocation(
cca.FromTo.From(),
targetURL,
cca.Source,
cca.credentialInfo.CredentialType,
cca.credentialInfo.OAuthTokenInfo.TokenCredential,
&options,
Expand Down Expand Up @@ -144,7 +143,7 @@ func removeBfsResources(cca *CookedCopyCmdArgs) (err error) {
sourceURL, _ := cca.Source.String()
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)

targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), sourceURL, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil)
targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil)
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions cmd/setPropertiesEnumerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator

var srcCredInfo common.CredentialInfo

if srcCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source.Value, cca.Source.SAS, true, cca.CpkOptions); err != nil {
if srcCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source, true, cca.CpkOptions); err != nil {
return nil, err
}
if cca.FromTo == common.EFromTo.FileNone() && (srcCredInfo.CredentialType == common.ECredentialType.Anonymous() && cca.Source.SAS == "") {
Expand Down Expand Up @@ -72,7 +72,6 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator
jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo)
}

targetURL, _ := cca.Source.String()
options := createClientOptions(common.AzcopyCurrentJobLogger, nil)
var fileClientOptions any
if cca.FromTo.From() == common.ELocation.File() {
Expand All @@ -81,7 +80,7 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator

targetServiceClient, err := common.GetServiceClientForLocation(
cca.FromTo.From(),
targetURL,
cca.Source,
cca.credentialInfo.CredentialType,
cca.credentialInfo.OAuthTokenInfo.TokenCredential,
&options,
Expand Down
Loading
Loading