diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ad8d8ae81..00f534c08 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -206,7 +206,9 @@ jobs: AZCOPY_E2E_CLASSIC_ACCOUNT_KEY: $(AZCOPY_E2E_CLASSIC_ACCOUNT_KEY) AZCOPY_E2E_LOG_OUTPUT: '$(System.DefaultWorkingDirectory)/logs' AZCOPY_E2E_OAUTH_MANAGED_DISK_CONFIG: $(AZCOPY_E2E_OAUTH_MANAGED_DISK_CONFIG) + AZCOPY_E2E_OAUTH_MANAGED_DISK_SNAPSHOT_CONFIG: $(AZCOPY_E2E_OAUTH_MANAGED_DISK_SNAPSHOT_CONFIG) AZCOPY_E2E_STD_MANAGED_DISK_CONFIG: $(AZCOPY_E2E_STD_MANAGED_DISK_CONFIG) + AZCOPY_E2E_STD_MANAGED_DISK_SNAPSHOT_CONFIG: $(AZCOPY_E2E_STD_MANAGED_DISK_SNAPSHOT_CONFIG) CPK_ENCRYPTION_KEY: $(CPK_ENCRYPTION_KEY) CPK_ENCRYPTION_KEY_SHA256: $(CPK_ENCRYPTION_KEY_SHA256) AZCOPY_E2E_EXECUTABLE_PATH: $(System.DefaultWorkingDirectory)/$(build_name) diff --git a/cmd/copy.go b/cmd/copy.go index 90ea2a1d7..0030d4c79 100644 --- a/cmd/copy.go +++ b/cmd/copy.go @@ -1297,7 +1297,7 @@ func (cca *CookedCopyCmdArgs) processRedirectionDownload(blobResource common.Res // The isPublic flag is useful in S2S transfers but doesn't much matter for download. Fortunately, no S2S happens here. // This means that if there's auth, there's auth. We're happy and can move on. // GetCredentialInfoForLocation also populates oauth token fields... so, it's very easy. - credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource.Value, blobResource.SAS, true, cca.CpkOptions) + credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource, true, cca.CpkOptions) if err != nil { return fmt.Errorf("fatal: cannot find auth on source blob URL: %s", err.Error()) @@ -1353,7 +1353,7 @@ func (cca *CookedCopyCmdArgs) processRedirectionUpload(blobResource common.Resou } // GetCredentialInfoForLocation populates oauth token fields... so, it's very easy. - credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource.Value, blobResource.SAS, false, cca.CpkOptions) + credInfo, _, err := GetCredentialInfoForLocation(ctx, common.ELocation.Blob(), blobResource, false, cca.CpkOptions) if err != nil { return fmt.Errorf("fatal: cannot find auth on destination blob URL: %s", err.Error()) @@ -1425,7 +1425,7 @@ func (cca *CookedCopyCmdArgs) getSrcCredential(ctx context.Context, jpo *common. panic("Invalid Source") } - srcCredInfo, isPublic, err := GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source.Value, cca.Source.SAS, true, cca.CpkOptions) + srcCredInfo, isPublic, err := GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source, true, cca.CpkOptions) if err != nil { return srcCredInfo, err // If S2S and source takes OAuthToken as its cred type (OR) source takes anonymous as its cred type, but it's not public and there's no SAS @@ -1483,11 +1483,9 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { // For upload&download, only one side need credential. // For S2S copy, as azcopy-v10 use Put*FromUrl, only one credential is needed for destination. if cca.credentialInfo.CredentialType, err = getCredentialType(ctx, rawFromToInfo{ - fromTo: cca.FromTo, - source: cca.Source.Value, - destination: cca.Destination.Value, - sourceSAS: cca.Source.SAS, - destinationSAS: cca.Destination.SAS, + fromTo: cca.FromTo, + source: cca.Source, + destination: cca.Destination, }, cca.CpkOptions); err != nil { return err } @@ -1556,10 +1554,9 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { if err != nil { return err } - sourceURL, _ := cca.Source.String() jobPartOrder.SrcServiceClient, err = common.GetServiceClientForLocation( cca.FromTo.From(), - sourceURL, + cca.Source, srcCredInfo.CredentialType, srcCredInfo.OAuthTokenInfo.TokenCredential, &options, @@ -1575,7 +1572,6 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { AllowSourceTrailingDot: cca.trailingDot == common.ETrailingDotOption.Enable() && cca.FromTo.From() == common.ELocation.File(), } } - dstURL, _ := cca.Destination.String() var srcCred *common.ScopedCredential if cca.FromTo.IsS2S() && srcCredInfo.CredentialType.IsAzureOAuth() { @@ -1584,7 +1580,7 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) { options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred) jobPartOrder.DstServiceClient, err = common.GetServiceClientForLocation( cca.FromTo.To(), - dstURL, + cca.Destination, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, diff --git a/cmd/copyEnumeratorInit.go b/cmd/copyEnumeratorInit.go index 0495b6407..4e05212e1 100755 --- a/cmd/copyEnumeratorInit.go +++ b/cmd/copyEnumeratorInit.go @@ -339,7 +339,7 @@ func (cca *CookedCopyCmdArgs) isDestDirectory(dst common.ResourceString, ctx *co return false } - if dstCredInfo, _, err = GetCredentialInfoForLocation(*ctx, cca.FromTo.To(), cca.Destination.Value, cca.Destination.SAS, false, cca.CpkOptions); err != nil { + if dstCredInfo, _, err = GetCredentialInfoForLocation(*ctx, cca.FromTo.To(), cca.Destination, false, cca.CpkOptions); err != nil { return false } @@ -436,15 +436,10 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA existingContainers[containerName] = true var dstCredInfo common.CredentialInfo - dstURL, err := dstWithSAS.String() - if err != nil { - return err - } - // 3minutes is enough time to list properties of a container, and create new if it does not exist. ctx, cancel := context.WithTimeout(parentCtx, time.Minute*3) defer cancel() - if dstCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.To(), cca.Destination.Value, cca.Destination.SAS, false, cca.CpkOptions); err != nil { + if dstCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.To(), cca.Destination, false, cca.CpkOptions); err != nil { return err } @@ -452,7 +447,7 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA sc, err := common.GetServiceClientForLocation( cca.FromTo.To(), - dstURL, + dstWithSAS, dstCredInfo.CredentialType, dstCredInfo.OAuthTokenInfo.TokenCredential, &options, diff --git a/cmd/credentialUtil.go b/cmd/credentialUtil.go index 4faf05ae8..13694f60c 100644 --- a/cmd/credentialUtil.go +++ b/cmd/credentialUtil.go @@ -209,9 +209,8 @@ func GetCredTypeFromEnvVar() common.CredentialType { } type rawFromToInfo struct { - fromTo common.FromTo - source, destination string - sourceSAS, destinationSAS string // Standalone SAS which might be provided + fromTo common.FromTo + source, destination common.ResourceString } const trustedSuffixesNameAAD = "trusted-microsoft-suffixes" @@ -369,7 +368,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk return false } - // This request will not be logged. This can fail, and too many Cx do not like this. + // This request will not be logged. This can fail, and too many Cx do not like this. clientOptions := ste.NewClientOptions(policy.RetryOptions{ MaxRetries: ste.UploadMaxTries, TryTimeout: ste.UploadTryTimeout, @@ -402,7 +401,7 @@ func isPublic(ctx context.Context, blobResourceURL string, cpkOptions common.Cpk // mdAccountNeedsOAuth pings the passed in md account, and checks if we need additional token with Disk-socpe func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions common.CpkOptions) bool { - // This request will not be logged. This can fail, and too many Cx do not like this. + // This request will not be logged. This can fail, and too many Cx do not like this. clientOptions := ste.NewClientOptions(policy.RetryOptions{ MaxRetries: ste.UploadMaxTries, TryTimeout: ste.UploadTryTimeout, @@ -430,11 +429,11 @@ func mdAccountNeedsOAuth(ctx context.Context, blobResourceURL string, cpkOptions return false } -func getCredentialTypeForLocation(ctx context.Context, location common.Location, resource, resourceSAS string, isSource bool, cpkOptions common.CpkOptions) (credType common.CredentialType, isPublic bool, err error) { - return doGetCredentialTypeForLocation(ctx, location, resource, resourceSAS, isSource, GetCredTypeFromEnvVar, cpkOptions) +func getCredentialTypeForLocation(ctx context.Context, location common.Location, resource common.ResourceString, isSource bool, cpkOptions common.CpkOptions) (credType common.CredentialType, isPublic bool, err error) { + return doGetCredentialTypeForLocation(ctx, location, resource, isSource, GetCredTypeFromEnvVar, cpkOptions) } -func doGetCredentialTypeForLocation(ctx context.Context, location common.Location, resource, resourceSAS string, isSource bool, getForcedCredType func() common.CredentialType, cpkOptions common.CpkOptions) (credType common.CredentialType, public bool, err error) { +func doGetCredentialTypeForLocation(ctx context.Context, location common.Location, resource common.ResourceString, isSource bool, getForcedCredType func() common.CredentialType, cpkOptions common.CpkOptions) (credType common.CredentialType, public bool, err error) { public = false err = nil @@ -453,7 +452,7 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio return } - if err = checkAuthSafeForTarget(credType, resource, cmdLineExtraSuffixesAAD, location); err != nil { + if err = checkAuthSafeForTarget(credType, resource.Value, cmdLineExtraSuffixesAAD, location); err != nil { credType = common.ECredentialType.Unknown() public = false } @@ -489,14 +488,14 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio // Special blob destinations - public and MD account needing oAuth if location == common.ELocation.Blob() { - if isSource && resourceSAS == "" && isPublic(ctx, resource, cpkOptions) { + uri, _ := resource.FullURL() + if isSource && resource.SAS == "" && isPublic(ctx, uri.String(), cpkOptions) { credType = common.ECredentialType.Anonymous() public = true return } - uri, _ := url.Parse(resource) - if strings.HasPrefix(uri.Host, "md-") && mdAccountNeedsOAuth(ctx, resource, cpkOptions) { + if strings.HasPrefix(uri.Host, "md-") && mdAccountNeedsOAuth(ctx, uri.String(), cpkOptions) { if !oAuthTokenExists() { return common.ECredentialType.Unknown(), false, common.NewAzError(common.EAzError.LoginCredMissing(), "No SAS token or OAuth token is present and the resource is not public") @@ -507,7 +506,7 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio } } - if resourceSAS != "" { + if resource.SAS != "" { credType = common.ECredentialType.Anonymous() return } @@ -535,10 +534,10 @@ func doGetCredentialTypeForLocation(ctx context.Context, location common.Locatio return } -func GetCredentialInfoForLocation(ctx context.Context, location common.Location, resource, resourceSAS string, isSource bool, cpkOptions common.CpkOptions) (credInfo common.CredentialInfo, isPublic bool, err error) { +func GetCredentialInfoForLocation(ctx context.Context, location common.Location, resource common.ResourceString, isSource bool, cpkOptions common.CpkOptions) (credInfo common.CredentialInfo, isPublic bool, err error) { // get the type - credInfo.CredentialType, isPublic, err = getCredentialTypeForLocation(ctx, location, resource, resourceSAS, isSource, cpkOptions) + credInfo.CredentialType, isPublic, err = getCredentialTypeForLocation(ctx, location, resource, isSource, cpkOptions) // flesh out the rest of the fields, for those types that require it if credInfo.CredentialType.IsAzureOAuth() { @@ -563,17 +562,17 @@ func getCredentialType(ctx context.Context, raw rawFromToInfo, cpkOptions common switch { case raw.fromTo.To().IsRemote(): // we authenticate to the destination. Source is assumed to be SAS, or public, or a local resource - credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.To(), raw.destination, raw.destinationSAS, false, common.CpkOptions{}) + credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.To(), raw.destination, false, common.CpkOptions{}) case raw.fromTo == common.EFromTo.BlobTrash() || raw.fromTo == common.EFromTo.BlobFSTrash() || raw.fromTo == common.EFromTo.FileTrash(): // For to Trash direction, use source as resource URL // Also, by setting isSource=false we inform getCredentialTypeForLocation() that resource // being deleted cannot be public. - credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, raw.sourceSAS, false, cpkOptions) + credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, false, cpkOptions) case raw.fromTo.From().IsRemote() && raw.fromTo.To().IsLocal(): // we authenticate to the source. - credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, raw.sourceSAS, true, cpkOptions) + credType, _, err = getCredentialTypeForLocation(ctx, raw.fromTo.From(), raw.source, true, cpkOptions) default: credType = common.ECredentialType.Anonymous() // Log the FromTo types which getCredentialType hasn't solved, in case of miss-use. diff --git a/cmd/jobsResume.go b/cmd/jobsResume.go index 1f8554722..32dcfb53e 100644 --- a/cmd/jobsResume.go +++ b/cmd/jobsResume.go @@ -248,8 +248,8 @@ type resumeCmdArgs struct { func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( ctx context.Context, fromTo common.FromTo, - source string, - destination string, + source common.ResourceString, + destination common.ResourceString, ) (*common.ServiceClient, *common.ServiceClient, error) { if len(rca.SourceSAS) > 0 && rca.SourceSAS[0] != '?' { rca.SourceSAS = "?" + rca.SourceSAS @@ -258,10 +258,12 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( rca.DestinationSAS = "?" + rca.DestinationSAS } + source.SAS = rca.SourceSAS + destination.SAS = rca.DestinationSAS + srcCredType, _, err := getCredentialTypeForLocation(ctx, fromTo.From(), source, - rca.SourceSAS, true, common.CpkOptions{}) if err != nil { @@ -271,7 +273,6 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( dstCredType, _, err := getCredentialTypeForLocation(ctx, fromTo.To(), destination, - rca.DestinationSAS, false, common.CpkOptions{}) if err != nil { @@ -295,7 +296,7 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( options := createClientOptions(common.AzcopyCurrentJobLogger, nil) - srcServiceClient, err := common.GetServiceClientForLocation(fromTo.From(), source+rca.SourceSAS, srcCredType, tc, &options, nil) + srcServiceClient, err := common.GetServiceClientForLocation(fromTo.From(), source, srcCredType, tc, &options, nil) if err != nil { return nil, nil, err } @@ -305,7 +306,7 @@ func (rca resumeCmdArgs) getSourceAndDestinationServiceClients( srcCred = common.NewScopedCredential(tc, srcCredType) } options = createClientOptions(common.AzcopyCurrentJobLogger, srcCred) - dstServiceClient, err := common.GetServiceClientForLocation(fromTo.To(), destination+rca.DestinationSAS, dstCredType, tc, &options, nil) + dstServiceClient, err := common.GetServiceClientForLocation(fromTo.To(), destination, dstCredType, tc, &options, nil) if err != nil { return nil, nil, err } @@ -375,23 +376,27 @@ func (rca resumeCmdArgs) process() error { // Initialize credential info. credentialInfo := common.CredentialInfo{} // TODO: Replace context with root context + srcResourceString, err := SplitResourceString(getJobFromToResponse.Source, getJobFromToResponse.FromTo.From()) + _ = err // todo + srcResourceString.SAS = rca.SourceSAS + dstResourceString, err := SplitResourceString(getJobFromToResponse.Destination, getJobFromToResponse.FromTo.To()) + _ = err // todo + dstResourceString.SAS = rca.DestinationSAS // we should stop using credentiaLInfo and use the clients instead. But before we fix // that there will be repeated calls to get Credential type for correctness. if credentialInfo.CredentialType, err = getCredentialType(ctx, rawFromToInfo{ - fromTo: getJobFromToResponse.FromTo, - source: getJobFromToResponse.Source, - destination: getJobFromToResponse.Destination, - sourceSAS: rca.SourceSAS, - destinationSAS: rca.DestinationSAS, + fromTo: getJobFromToResponse.FromTo, + source: srcResourceString, + destination: dstResourceString, }, common.CpkOptions{}); err != nil { return err } srcServiceClient, dstServiceClient, err := rca.getSourceAndDestinationServiceClients( ctx, getJobFromToResponse.FromTo, - getJobFromToResponse.Source, - getJobFromToResponse.Destination, + srcResourceString, + dstResourceString, ) if err != nil { return errors.New("could not create service clients " + err.Error()) diff --git a/cmd/list.go b/cmd/list.go index 4a6437578..0eeefab0c 100755 --- a/cmd/list.go +++ b/cmd/list.go @@ -237,7 +237,7 @@ func (cooked cookedListCmdArgs) HandleListContainerCommand() (err error) { } // isSource is rather misnomer for canBePublic. We can list public containers, and hence isSource=true - if credentialInfo, _, err = GetCredentialInfoForLocation(ctx, cooked.location, source.Value, source.SAS, true, common.CpkOptions{}); err != nil { + if credentialInfo, _, err = GetCredentialInfoForLocation(ctx, cooked.location, source, true, common.CpkOptions{}); err != nil { return fmt.Errorf("failed to obtain credential info: %s", err.Error()) } else if cooked.location == cooked.location.File() && source.SAS == "" { return errors.New("azure files requires a SAS token for authentication") diff --git a/cmd/make.go b/cmd/make.go index e2124cbc4..13c900c20 100644 --- a/cmd/make.go +++ b/cmd/make.go @@ -85,7 +85,7 @@ func (cookedArgs cookedMakeCmdArgs) process() (err error) { return fmt.Errorf("failed to resolve target: %w", err) } - credentialInfo, _, err := GetCredentialInfoForLocation(ctx, cookedArgs.resourceLocation, resourceStringParts.Value, resourceStringParts.SAS, false, common.CpkOptions{}) + credentialInfo, _, err := GetCredentialInfoForLocation(ctx, cookedArgs.resourceLocation, resourceStringParts, false, common.CpkOptions{}) if err != nil { return err } diff --git a/cmd/removeEnumerator.go b/cmd/removeEnumerator.go index 2a67f77e2..7ca0f70c3 100755 --- a/cmd/removeEnumerator.go +++ b/cmd/removeEnumerator.go @@ -85,7 +85,6 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo) } - targetURL, _ := cca.Source.String() from := cca.FromTo.From() if !from.SupportsTrailingDot() { cca.trailingDot = common.ETrailingDotOption.Disable() @@ -97,7 +96,7 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er } targetServiceClient, err := common.GetServiceClientForLocation( cca.FromTo.From(), - targetURL, + cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, @@ -144,7 +143,7 @@ func removeBfsResources(cca *CookedCopyCmdArgs) (err error) { sourceURL, _ := cca.Source.String() options := createClientOptions(common.AzcopyCurrentJobLogger, nil) - targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), sourceURL, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil) + targetServiceClient, err := common.GetServiceClientForLocation(cca.FromTo.From(), cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, nil) if err != nil { return err } diff --git a/cmd/setPropertiesEnumerator.go b/cmd/setPropertiesEnumerator.go index f2ebce28d..14e9be84e 100755 --- a/cmd/setPropertiesEnumerator.go +++ b/cmd/setPropertiesEnumerator.go @@ -38,7 +38,7 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator var srcCredInfo common.CredentialInfo - if srcCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source.Value, cca.Source.SAS, true, cca.CpkOptions); err != nil { + if srcCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source, true, cca.CpkOptions); err != nil { return nil, err } if cca.FromTo == common.EFromTo.FileNone() && (srcCredInfo.CredentialType == common.ECredentialType.Anonymous() && cca.Source.SAS == "") { @@ -72,7 +72,6 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator jobsAdmin.JobsAdmin.LogToJobLog(message, common.LogInfo) } - targetURL, _ := cca.Source.String() options := createClientOptions(common.AzcopyCurrentJobLogger, nil) var fileClientOptions any if cca.FromTo.From() == common.ELocation.File() { @@ -81,7 +80,7 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator targetServiceClient, err := common.GetServiceClientForLocation( cca.FromTo.From(), - targetURL, + cca.Source, cca.credentialInfo.CredentialType, cca.credentialInfo.OAuthTokenInfo.TokenCredential, &options, diff --git a/cmd/sync.go b/cmd/sync.go index 05f3d7119..ef680c03c 100644 --- a/cmd/sync.go +++ b/cmd/sync.go @@ -676,12 +676,12 @@ func (cca *cookedSyncCmdArgs) process() (err error) { // Verifies credential type and initializes credential info. // Note that this is for the destination. - cca.credentialInfo, _, err = GetCredentialInfoForLocation(ctx, cca.fromTo.To(), cca.destination.Value, cca.destination.SAS, false, cca.cpkOptions) + cca.credentialInfo, _, err = GetCredentialInfoForLocation(ctx, cca.fromTo.To(), cca.destination, false, cca.cpkOptions) if err != nil { return err } - srcCredInfo, _, err := GetCredentialInfoForLocation(ctx, cca.fromTo.From(), cca.source.Value, cca.source.SAS, true, cca.cpkOptions) + srcCredInfo, _, err := GetCredentialInfoForLocation(ctx, cca.fromTo.From(), cca.source, true, cca.cpkOptions) if err != nil { return err } diff --git a/cmd/syncEnumerator.go b/cmd/syncEnumerator.go index b241f3023..8000d5850 100644 --- a/cmd/syncEnumerator.go +++ b/cmd/syncEnumerator.go @@ -37,7 +37,7 @@ import ( func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *syncEnumerator, err error) { - srcCredInfo, _, err := GetCredentialInfoForLocation(ctx, cca.fromTo.From(), cca.source.Value, cca.source.SAS, true, cca.cpkOptions) + srcCredInfo, _, err := GetCredentialInfoForLocation(ctx, cca.fromTo.From(), cca.source, true, cca.cpkOptions) if err != nil { return nil, err @@ -73,8 +73,7 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s } // Because we can't trust cca.credinfo, given that it's for the overall job, not the individual traversers, we get cred info again here. - dstCredInfo, _, err := GetCredentialInfoForLocation(ctx, cca.fromTo.To(), cca.destination.Value, - cca.destination.SAS, false, cca.cpkOptions) + dstCredInfo, _, err := GetCredentialInfoForLocation(ctx, cca.fromTo.To(), cca.destination, false, cca.cpkOptions) if err != nil { return nil, err @@ -182,7 +181,7 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s options := createClientOptions(common.AzcopyCurrentJobLogger, nil) - // Create Source Client. + // Create Source Client. var azureFileSpecificOptions any if cca.fromTo.From() == common.ELocation.File() { azureFileSpecificOptions = &common.FileClientOptions{ @@ -190,10 +189,9 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s } } - sourceURL, _ := cca.source.String() copyJobTemplate.SrcServiceClient, err = common.GetServiceClientForLocation( cca.fromTo.From(), - sourceURL, + cca.source, srcCredInfo.CredentialType, srcCredInfo.OAuthTokenInfo.TokenCredential, &options, @@ -217,10 +215,9 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s } options = createClientOptions(common.AzcopyCurrentJobLogger, srcTokenCred) - dstURL, _ := cca.destination.String() copyJobTemplate.DstServiceClient, err = common.GetServiceClientForLocation( cca.fromTo.To(), - dstURL, + cca.destination, dstCredInfo.CredentialType, dstCredInfo.OAuthTokenInfo.TokenCredential, &options, diff --git a/cmd/zc_enumerator.go b/cmd/zc_enumerator.go index 1a75cab3f..6e5208c9d 100644 --- a/cmd/zc_enumerator.go +++ b/cmd/zc_enumerator.go @@ -444,7 +444,13 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat blobURLParts.BlobName = "" blobURLParts.Snapshot = "" blobURLParts.VersionID = "" - c, err := common.GetServiceClientForLocation(common.ELocation.Blob(), blobURLParts.String(), credential.CredentialType, credential.OAuthTokenInfo.TokenCredential, &options, nil) + + res, err := SplitResourceString(blobURLParts.String(), common.ELocation.Blob()) + if err != nil { + return nil, err + } + + c, err := common.GetServiceClientForLocation(common.ELocation.Blob(), res, credential.CredentialType, credential.OAuthTokenInfo.TokenCredential, &options, nil) if err != nil { return nil, err } @@ -489,7 +495,13 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat fileOptions := &common.FileClientOptions{ AllowTrailingDot: trailingDot == common.ETrailingDotOption.Enable(), } - c, err := common.GetServiceClientForLocation(common.ELocation.File(), fileURLParts.String(), credential.CredentialType, credential.OAuthTokenInfo.TokenCredential, &options, fileOptions) + + res, err := SplitResourceString(fileURLParts.String(), common.ELocation.File()) + if err != nil { + return nil, err + } + + c, err := common.GetServiceClientForLocation(common.ELocation.File(), res, credential.CredentialType, credential.OAuthTokenInfo.TokenCredential, &options, fileOptions) if err != nil { return nil, err } @@ -530,7 +542,12 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat blobURLParts.Snapshot = "" blobURLParts.VersionID = "" - c, err := common.GetServiceClientForLocation(common.ELocation.Blob(), blobURLParts.String(), credential.CredentialType, credential.OAuthTokenInfo.TokenCredential, &options, nil) + res, err := SplitResourceString(blobURLParts.String(), common.ELocation.Blob()) + if err != nil { + return nil, err + } + + c, err := common.GetServiceClientForLocation(common.ELocation.Blob(), res, credential.CredentialType, credential.OAuthTokenInfo.TokenCredential, &options, nil) if err != nil { return nil, err } diff --git a/cmd/zt_credentialUtil_test.go b/cmd/zt_credentialUtil_test.go index 22c64a222..b1a6c8e4c 100644 --- a/cmd/zt_credentialUtil_test.go +++ b/cmd/zt_credentialUtil_test.go @@ -96,16 +96,19 @@ func TestCheckAuthSafeForTarget(t *testing.T) { } func TestCheckAuthSafeForTargetIsCalledWhenGettingAuthType(t *testing.T) { - common.AzcopyJobPlanFolder = os.TempDir() + common.AzcopyJobPlanFolder = os.TempDir() a := assert.New(t) mockGetCredTypeFromEnvVar := func() common.CredentialType { return common.ECredentialType.OAuthToken() // force it to OAuth, which is the case we want to test } + res, err := SplitResourceString("http://notblob.example.com", common.ELocation.Blob()) + a.NoError(err) + // Call our core cred type getter function, in a way that will fail the safety check, and assert // that it really does fail. // This checks that our safety check is hooked into the main logic - _, _, err := doGetCredentialTypeForLocation(context.Background(), common.ELocation.Blob(), "http://notblob.example.com", "", true, mockGetCredTypeFromEnvVar, common.CpkOptions{}) + _, _, err = doGetCredentialTypeForLocation(context.Background(), common.ELocation.Blob(), res, true, mockGetCredTypeFromEnvVar, common.CpkOptions{}) a.NotNil(err) a.True(strings.Contains(err.Error(), "If this URL is in fact an Azure service, you can enable Azure authentication to notblob.example.com.")) } @@ -116,10 +119,13 @@ func TestCheckAuthSafeForTargetIsCalledWhenGettingAuthTypeMDOAuth(t *testing.T) return common.ECredentialType.MDOAuthToken() // force it to OAuth, which is the case we want to test } + res, err := SplitResourceString("http://notblob.example.com", common.ELocation.Blob()) + a.NoError(err) + // Call our core cred type getter function, in a way that will fail the safety check, and assert // that it really does fail. // This checks that our safety check is hooked into the main logic - _, _, err := doGetCredentialTypeForLocation(context.Background(), common.ELocation.Blob(), "http://notblob.example.com", "", true, mockGetCredTypeFromEnvVar, common.CpkOptions{}) + _, _, err = doGetCredentialTypeForLocation(context.Background(), common.ELocation.Blob(), res, true, mockGetCredTypeFromEnvVar, common.CpkOptions{}) a.NotNil(err) a.True(strings.Contains(err.Error(), "If this URL is in fact an Azure service, you can enable Azure authentication to notblob.example.com.")) } @@ -130,11 +136,11 @@ func TestCheckAuthSafeForTargetIsCalledWhenGettingAuthTypeMDOAuth(t *testing.T) */ func TestIsPublic(t *testing.T) { a := assert.New(t) - ctx, _ := context.WithTimeout(context.TODO(), 5 * time.Minute) + ctx, _ := context.WithTimeout(context.TODO(), 5*time.Minute) bsc := getBlobServiceClient() ctr, _ := getContainerClient(a, bsc) defer ctr.Delete(ctx, nil) - + publicAccess := container.PublicAccessTypeContainer // Create a public container @@ -155,4 +161,4 @@ func TestIsPublic(t *testing.T) { a.True(isPublic(ctx, bb.URL(), common.CpkOptions{})) -} \ No newline at end of file +} diff --git a/common/rpc-models.go b/common/rpc-models.go index 2d5f6b6ff..1e0901a70 100644 --- a/common/rpc-models.go +++ b/common/rpc-models.go @@ -98,6 +98,9 @@ func (r ResourceString) ValueLocal() string { func (r ResourceString) addParamsToUrl(u *url.URL, sas, extraQuery string) { for _, p := range []string{sas, extraQuery} { + // Sanity check: trim ? from the start + p = strings.TrimPrefix(p, "?") + if p == "" { continue } diff --git a/common/util.go b/common/util.go index d50f96a30..c4acd50c7 100644 --- a/common/util.go +++ b/common/util.go @@ -3,19 +3,20 @@ package common import ( "context" "errors" + "fmt" "net" "net/url" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake" - "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" blobservice "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake" datalake "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/service" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/directory" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/fileerror" fileservice "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/service" ) @@ -108,13 +109,18 @@ type FileClientOptions struct { // container and file related details before creating the client. locationSpecificOptions // are required currently only for files. func GetServiceClientForLocation(loc Location, - resourceURL string, + resource ResourceString, credType CredentialType, cred azcore.TokenCredential, policyOptions *azcore.ClientOptions, locationSpecificOptions any, ) (*ServiceClient, error) { ret := &ServiceClient{} + resourceURL, err := resource.String() + if err != nil { + return nil, fmt.Errorf("failed to get resource string: %w", err) + } + switch loc { case ELocation.BlobFS(), ELocation.Blob(): // Since we always may need to interact with DFS while working with Blob, we should just attach both. datalakeURLParts, err := azdatalake.ParseURL(resourceURL) @@ -241,7 +247,7 @@ func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) * var scope string if !credType.IsAzureOAuth() { return nil - } else if credType == ECredentialType.MDOAuthToken() { + } else if credType == ECredentialType.MDOAuthToken() { scope = ManagedDiskScope } else if credType == ECredentialType.OAuthToken() { scope = StorageScope @@ -250,13 +256,13 @@ func NewScopedCredential(cred azcore.TokenCredential, credType CredentialType) * } type ScopedCredential struct { - cred azcore.TokenCredential + cred azcore.TokenCredential scopes []string } func (s *ScopedCredential) GetToken(ctx context.Context, - _ policy.TokenRequestOptions)( - azcore.AccessToken, error) { + _ policy.TokenRequestOptions) ( + azcore.AccessToken, error) { return s.cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: s.scopes}) } @@ -289,9 +295,9 @@ func (s *ServiceClient) DatalakeServiceClient() (*datalake.Client, error) { // This is currently used only in testcases func NewServiceClient(bsc *blobservice.Client, - fsc *fileservice.Client, - dsc *datalake.Client) *ServiceClient { - return &ServiceClient { + fsc *fileservice.Client, + dsc *datalake.Client) *ServiceClient { + return &ServiceClient{ bsc: bsc, fsc: fsc, dsc: dsc, @@ -388,4 +394,4 @@ func DoWithOverrideReadOnlyOnAzureFiles(ctx context.Context, action func() (inte // retry the action _, err = action() return err -} \ No newline at end of file +} diff --git a/e2etest/.azcopy/latest_version.txt b/e2etest/.azcopy/latest_version.txt new file mode 100644 index 000000000..636da43e4 --- /dev/null +++ b/e2etest/.azcopy/latest_version.txt @@ -0,0 +1 @@ +10.22.2,2024-01-18T12:52:03Z \ No newline at end of file diff --git a/e2etest/config.go b/e2etest/config.go index 72efc6a32..dfa63ddf0 100644 --- a/e2etest/config.go +++ b/e2etest/config.go @@ -121,13 +121,15 @@ func (AccountType) OAuthManagedDisk() AccountType { return AccountTy func (AccountType) S3() AccountType { return AccountType(6) } // Stub, for future testing use func (AccountType) GCP() AccountType { return AccountType(7) } // Stub, for future testing use func (AccountType) Azurite() AccountType { return AccountType(8) } +func (AccountType) ManagedDiskSnapshot() AccountType { return AccountType(9) } +func (AccountType) ManagedDiskSnapshotOAuth() AccountType { return AccountType(10) } func (o AccountType) String() string { return enum.StringInt(o, reflect.TypeOf(o)) } func (o AccountType) IsManagedDisk() bool { - return o == o.StdManagedDisk() || o == o.OAuthManagedDisk() + return o == o.StdManagedDisk() || o == o.OAuthManagedDisk() || o == o.ManagedDiskSnapshot() || o == o.ManagedDiskSnapshotOAuth() } func (o AccountType) IsBlobOnly() bool { @@ -141,19 +143,28 @@ type ManagedDiskConfig struct { SubscriptionID string ResourceGroupName string DiskName string + oauth AccessToken + isSnapshot bool } var ClassicE2EOAuthCache *OAuthCache func (gim GlobalInputManager) GetMDConfig(accountType AccountType) (*ManagedDiskConfig, error) { var mdConfigVar string + var isSnapshot bool switch accountType { case EAccountType.StdManagedDisk(): mdConfigVar = "AZCOPY_E2E_STD_MANAGED_DISK_CONFIG" case EAccountType.OAuthManagedDisk(): mdConfigVar = "AZCOPY_E2E_OAUTH_MANAGED_DISK_CONFIG" + case EAccountType.ManagedDiskSnapshot(): + mdConfigVar = "AZCOPY_E2E_STD_MANAGED_DISK_SNAPSHOT_CONFIG" + isSnapshot = true + case EAccountType.ManagedDiskSnapshotOAuth(): + mdConfigVar = "AZCOPY_E2E_OAUTH_MANAGED_DISK_SNAPSHOT_CONFIG" + isSnapshot = true default: return nil, fmt.Errorf("account type %s is invalid for GetMDConfig", accountType.String()) } @@ -169,6 +180,8 @@ func (gim GlobalInputManager) GetMDConfig(accountType AccountType) (*ManagedDisk return nil, fmt.Errorf("failed to parse config") // Outputting the error may reveal semi-sensitive info like subscription ID } + // Attach additional details to the config + out.isSnapshot = isSnapshot err = gim.SetupClassicOAuthCache() if err != nil { return nil, fmt.Errorf("failed to setup OAuth cache: %w", err) diff --git a/e2etest/declarativeResourceManagers.go b/e2etest/declarativeResourceManagers.go index 5e3a74dc5..5bb33bec2 100644 --- a/e2etest/declarativeResourceManagers.go +++ b/e2etest/declarativeResourceManagers.go @@ -24,6 +24,7 @@ import ( "fmt" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" + blobsas "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/datalakeerror" datalakedirectory "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/directory" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file" @@ -548,6 +549,11 @@ func (r *resourceManagedDisk) createLocation(a asserter, s *scenario) { uri, err := r.config.GetAccess() a.AssertNoErr(err) + snapshotID := uri.Query().Get("snapshot") + if r.config.isSnapshot { + a.Assert(snapshotID, notEquals(), "", "Snapshot target must be incremental, or no snapshot query value is present") + } + r.accessURI = uri } @@ -570,8 +576,11 @@ func (r *resourceManagedDisk) downloadContent(a asserter, options downloadConten // cleanup also usurps traditional resourceManager functionality. func (r *resourceManagedDisk) cleanup(a asserter) { - // revoking access isn't required and causes funky behaviour for testing that might require a distributed mutex. - // todo: we should create managed disks as needed with the requirements rather than using a single MD should we plan to do read-write tests. + err := r.config.RevokeAccess() + a.AssertNoErr(err) + + // The signed identifier cache supposedly lasts 30s, so we'll assume that's a safe break time. + time.Sleep(time.Second * 30) } // getParam works functionally different because resourceManagerDisk inherently only targets a single file. @@ -579,7 +588,11 @@ func (r *resourceManagedDisk) getParam(a asserter, stripTopDir, withSas bool, wi out := *r.accessURI // clone the URI if !withSas { - out.RawQuery = "" + //out.RawQuery = "" + parts, err := blob.ParseURL(out.String()) + a.AssertNoErr(err, "url should parse, sanity check") + parts.SAS = blobsas.QueryParameters{} + return parts.String() } toReturn := out.String() diff --git a/e2etest/declarativeRunner.go b/e2etest/declarativeRunner.go index 3775e6017..eefc08bd7 100644 --- a/e2etest/declarativeRunner.go +++ b/e2etest/declarativeRunner.go @@ -81,12 +81,12 @@ func getValidCredCombinationsForFromTo(fromTo common.FromTo, requestedCredential } for _, srcCredType := range sourceTypes { - if srcCredType == common.ECredentialType.MDOAuthToken() && accountTypes[0] != EAccountType.OAuthManagedDisk() { + if srcCredType == common.ECredentialType.MDOAuthToken() && accountTypes[0] != EAccountType.OAuthManagedDisk() && accountTypes[0] != EAccountType.ManagedDiskSnapshotOAuth() { continue // invalid selection } for _, dstCredType := range validCredTypesPerLocation[fromTo.To()] { - if dstCredType == common.ECredentialType.MDOAuthToken() && accountTypes[1] != EAccountType.OAuthManagedDisk() { + if dstCredType == common.ECredentialType.MDOAuthToken() && accountTypes[1] != EAccountType.OAuthManagedDisk() && accountTypes[0] != EAccountType.ManagedDiskSnapshotOAuth() { continue // invalid selection } diff --git a/e2etest/declarativeScenario.go b/e2etest/declarativeScenario.go index ef70b9e96..9ba5949e7 100644 --- a/e2etest/declarativeScenario.go +++ b/e2etest/declarativeScenario.go @@ -113,6 +113,8 @@ func (s *scenario) Run() { if s.destAccountType.IsManagedDisk() { s.a.Assert(s.destAccountType, notEquals(), EAccountType.StdManagedDisk(), "Upload is not supported in MD testing yet") s.a.Assert(s.destAccountType, notEquals(), EAccountType.OAuthManagedDisk(), "Upload is not supported in MD testing yet") + s.a.Assert(s.destAccountType, notEquals(), EAccountType.ManagedDiskSnapshot(), "Cannot upload to a MD snapshot") + s.a.Assert(s.destAccountType, notEquals(), EAccountType.ManagedDiskSnapshotOAuth(), "Cannot upload to a MD snapshot") s.a.Assert(true, equals(), s.fromTo.From() == common.ELocation.Blob() || s.fromTo.From() == common.ELocation.BlobFS()) } @@ -380,10 +382,10 @@ func (s *scenario) resumeAzCopy(logDir string) { r := newTestRunner() if sas := s.state.source.getSAS(); s.GetTestFiles().sourcePublic == nil && sas != "" { - r.flags["source-sas"] = sas + r.flags["source-sas"] = strings.TrimPrefix(sas, "?") } if sas := s.state.dest.getSAS(); sas != "" { - r.flags["destination-sas"] = sas + r.flags["destination-sas"] = strings.TrimPrefix(sas, "?") } // use the general-purpose "after start" mechanism, provided by execDebuggableWithOutput, @@ -471,10 +473,8 @@ func (s *scenario) getTransferInfo() (srcRoot string, dstRoot string, expectFold srcBase := filepath.Base(srcRoot) srcRootURL, err := url.Parse(srcRoot) if err == nil { - snapshotID := srcRootURL.Query().Get("sharesnapshot") - if snapshotID != "" { - srcBase = filepath.Base(strings.TrimSuffix(srcRoot, "?sharesnapshot="+snapshotID)) - } + srcBase, _ = trimBaseSnapshotDetails(s.a, srcRootURL, s.fromTo.From(), s.srcAccountType) + srcBase = filepath.Base(srcBase) } // do we expect folder transfers diff --git a/e2etest/managedDisks.go b/e2etest/managedDisks.go index 2be956e7d..145cd96d7 100644 --- a/e2etest/managedDisks.go +++ b/e2etest/managedDisks.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/Azure/azure-storage-azcopy/v10/common" "io" "net/http" "net/url" @@ -17,7 +18,12 @@ func (config *ManagedDiskConfig) GetMDURL() (*url.URL, error) { return nil, fmt.Errorf("one or more important details are missing in the config") } - uri := fmt.Sprintf("https://management.azure.com/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s?api-version=2022-03-02", config.SubscriptionID, config.ResourceGroupName, config.DiskName) + // the API is the same, but the provider is different + uriFormat := common.Iff(config.isSnapshot, + "https://management.azure.com/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/snapshots/%s?api-version=2023-04-02", + "https://management.azure.com/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s?api-version=2023-04-02") + + uri := fmt.Sprintf(uriFormat, config.SubscriptionID, config.ResourceGroupName, config.DiskName) out, err := url.Parse(uri) if err != nil { return nil, fmt.Errorf("failed to parse URI (maybe some detail of the config was formatted invalid?)") diff --git a/e2etest/newe2e_account_registry.go b/e2etest/newe2e_account_registry.go index 87d1a368c..8bc2af8a4 100644 --- a/e2etest/newe2e_account_registry.go +++ b/e2etest/newe2e_account_registry.go @@ -55,7 +55,7 @@ func CreateAccount(a Asserter, accountType AccountType, options *CreateAccountOp } accountARMDefinition := ARMStorageAccountCreateParams{ - Location: "West US", // todo configurable + Location: "West US 2", // todo configurable } switch accountType { // https://learn.microsoft.com/en-us/azure/storage/common/storage-account-create?tabs=azure-portal#storage-account-type-parameters diff --git a/e2etest/validator.go b/e2etest/validator.go index 9f24397c4..189288fa8 100644 --- a/e2etest/validator.go +++ b/e2etest/validator.go @@ -55,6 +55,35 @@ func fixSlashes(s string, loc common.Location) string { // versionIDRegex is intended to capture variations of the destination version ID. var versionIDRegex = regexp.MustCompile("^\\d{4}-\\d{2}-\\d{2}T\\d{2}[-:]\\d{2}[-:]\\d{2}\\.\\d{7}Z") +func trimBaseSnapshotDetails(c asserter, url *url.URL, location common.Location, acctType AccountType) (trimmed, snapshot string) { + switch { + case location == common.ELocation.File(): + snapshot = url.Query().Get("sharesnapshot") + if snapshot != "" { + query := url.Query() + query.Del("sharesnapshot") + url.RawQuery = query.Encode() + trimmed = url.String() + } else { + trimmed = url.String() + } + case location == common.ELocation.Blob() && acctType.IsManagedDisk(): + snapshot = url.Query().Get("snapshot") + if snapshot != "" { + query := url.Query() + query.Del("snapshot") + url.RawQuery = query.Encode() + trimmed = url.String() + } else { + trimmed = url.String() + } + default: + trimmed = url.String() + } + + return +} + func (Validator) ValidateRemoveTransfer(c asserter, isSrcEncoded bool, isDstEncoded bool, sourcePrefix string, destinationPrefix string, expectedTransfers []*testObject, actualTransfers []common.TransferDetail, statusToTest common.TransferStatus) { // TODO: Think of how to validate files in case of remove @@ -71,10 +100,7 @@ func (Validator) ValidateCopyTransfersAreScheduled(s *scenario, isSrcEncoded boo // i.e. source is a URL srcPrefixURL, err := url.Parse(sourcePrefix) if err == nil { - snapshotID = srcPrefixURL.Query().Get("sharesnapshot") - if snapshotID != "" { - sourcePrefix = strings.TrimSuffix(sourcePrefix, "?sharesnapshot="+snapshotID) - } + sourcePrefix, snapshotID = trimBaseSnapshotDetails(c, srcPrefixURL, s.fromTo.From(), s.srcAccountType) } } @@ -100,7 +126,9 @@ func (Validator) ValidateCopyTransfersAreScheduled(s *scenario, isSrcEncoded boo for _, transfer := range actualTransfers { if snapshotID != "" { c.Assert(strings.Contains(transfer.Src, snapshotID), equals(), true) - transfer.Src = strings.TrimSuffix(transfer.Src, "?sharesnapshot="+snapshotID) + uri, err := url.Parse(transfer.Src) + c.AssertNoErr(err, "url must parse, sanity check") + transfer.Src, _ = trimBaseSnapshotDetails(c, uri, s.fromTo.From(), s.srcAccountType) } srcRelativeFilePath := strings.Trim(strings.TrimPrefix(makeSlashesComparable(transfer.Src), sourcePrefix), "/") diff --git a/e2etest/zt_managed_disks_test.go b/e2etest/zt_managed_disks_test.go index 286037a51..6139cda52 100644 --- a/e2etest/zt_managed_disks_test.go +++ b/e2etest/zt_managed_disks_test.go @@ -22,13 +22,20 @@ package e2etest import ( "github.com/Azure/azure-storage-azcopy/v10/common" + "runtime" "testing" + "time" ) // Purpose: Tests for the special cases that relate to moving managed disks (default local VHD to page blob; special handling for // md- and md-impex URLs. func TestManagedDisks_NoOAuthRequired(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Limit runs to Linux so no simultaneous runs occur") + return + } + RunScenarios( t, eOperation.Copy(), @@ -51,26 +58,92 @@ func TestManagedDisks_NoOAuthRequired(t *testing.T) { ) } +func TestManagedDisks_Snapshot(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Limit runs to Linux so no simultaneous runs occur") + return + } + + RunScenarios( + t, + eOperation.Copy(), + eTestFromTo.Other(common.EFromTo.BlobLocal(), common.EFromTo.BlobBlob()), // It's relevant to test blobblob since this interfaces with x-ms-copysourceauthorization + eValidate.Auto(), + anonymousAuthOnly, + anonymousAuthOnly, + params{ + disableParallelTesting: true, + }, + nil, + testFiles{ + shouldTransfer: []interface{}{ + "", + }, + }, // Managed disks will always have a transfer target of "" + EAccountType.Standard(), + EAccountType.ManagedDiskSnapshot(), + "", + ) +} + +func TestManagedDisks_SnapshotOAuth(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Limit runs to Linux so no simultaneous runs occur") + return + } + + RunScenarios( + t, + eOperation.Copy(), + eTestFromTo.Other(common.EFromTo.BlobLocal(), common.EFromTo.BlobBlob()), // It's relevant to test blobblob since this interfaces with x-ms-copysourceauthorization + eValidate.Auto(), + []common.CredentialType{common.ECredentialType.MDOAuthToken()}, + []common.CredentialType{common.ECredentialType.Anonymous(), common.ECredentialType.OAuthToken()}, + params{ + disableParallelTesting: true, + }, + nil, + testFiles{ + shouldTransfer: []interface{}{ + "", + }, + }, // Managed disks will always have a transfer target of "" + EAccountType.Standard(), + EAccountType.ManagedDiskSnapshotOAuth(), + "", + ) +} + // Service issue causes occasional flickers in feature functionality; enough that testing is problematic. Temporarily disabled until issue is resolved. -// func TestManagedDisks_OAuthRequired(t *testing.T) { -// RunScenarios( -// t, -// eOperation.Copy(), -// eTestFromTo.Other(common.EFromTo.BlobLocal(), common.EFromTo.BlobBlob()), // It's relevant to test blobblob since this interfaces with x-ms-copysourceauthorization -// eValidate.Auto(), -// []common.CredentialType{common.ECredentialType.MDOAuthToken()}, -// []common.CredentialType{common.ECredentialType.Anonymous(), common.ECredentialType.OAuthToken()}, -// params{ -// disableParallelTesting: true, // testing is implemented with a single managed disk -// }, -// nil, -// testFiles{ -// shouldTransfer: []interface{}{ -// "", -// }, -// }, // Managed disks will always have a transfer target of "" -// EAccountType.Standard(), -// EAccountType.OAuthManagedDisk(), -// "", -// ) -// } +func TestManagedDisks_OAuthRequired(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Limit runs to Linux so no simultaneous runs occur") + return + } + + RunScenarios( + t, + eOperation.Copy(), + eTestFromTo.Other(common.EFromTo.BlobLocal(), common.EFromTo.BlobBlob()), // It's relevant to test blobblob since this interfaces with x-ms-copysourceauthorization + eValidate.Auto(), + []common.CredentialType{common.ECredentialType.MDOAuthToken()}, + []common.CredentialType{common.ECredentialType.Anonymous(), common.ECredentialType.OAuthToken()}, + params{ + disableParallelTesting: true, // testing is implemented with a single managed disk + }, + &hooks{ + beforeRunJob: func(h hookHelper) { + // try giving the service some time to think + time.Sleep(time.Second * 30) + }, + }, + testFiles{ + shouldTransfer: []interface{}{ + "", + }, + }, // Managed disks will always have a transfer target of "" + EAccountType.Standard(), + EAccountType.OAuthManagedDisk(), + "", + ) +} diff --git a/ste/testJobPartTransferManager_test.go b/ste/testJobPartTransferManager_test.go index 61b675db2..a47f82bc9 100644 --- a/ste/testJobPartTransferManager_test.go +++ b/ste/testJobPartTransferManager_test.go @@ -57,9 +57,10 @@ func (t *testJobPartTransferManager) SrcServiceClient() *common.ServiceClient { AllowTrailingDot: true, } } + client, _ := common.GetServiceClientForLocation( t.fromTo.From(), - t.info.Source, + common.ResourceString{Value: t.info.Source}, t.S2SSourceCredentialInfo().CredentialType, t.S2SSourceCredentialInfo().OAuthTokenInfo.TokenCredential, &options, @@ -77,9 +78,10 @@ func (t *testJobPartTransferManager) DstServiceClient() *common.ServiceClient { AllowSourceTrailingDot: true, } } + client, _ := common.GetServiceClientForLocation( t.fromTo.To(), - t.info.Destination, + common.ResourceString{Value: t.info.Destination}, t.CredentialInfo().CredentialType, t.CredentialInfo().OAuthTokenInfo.TokenCredential, &options,