Skip to content

Commit

Permalink
Use common download logic for agent downloads (#2682)
Browse files Browse the repository at this point in the history
* Use common download logic for agent downloads

* rename method

* add unit test

Co-authored-by: narrieta <narrieta>
  • Loading branch information
narrieta committed Oct 18, 2022
1 parent faa8b14 commit 527443c
Show file tree
Hide file tree
Showing 13 changed files with 248 additions and 567 deletions.
7 changes: 0 additions & 7 deletions azurelinuxagent/common/protocol/goal_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,6 @@ def _fetch_manifest(self, manifest_type, name, uris):
except Exception as e:
raise ProtocolError("Failed to retrieve {0} manifest. Error: {1}".format(manifest_type, ustr(e)))

def download_extension(self, uris, destination, on_downloaded=lambda: True):
"""
This is a convenience method that wraps WireClient.download_extension(), but adds the required 'use_verify_header' parameter.
"""
is_fast_track = self.extensions_goal_state.source == GoalStateSource.FastTrack
self._wire_client.download_extension(uris, destination, use_verify_header=is_fast_track, on_downloaded=on_downloaded)

@staticmethod
def update_host_plugin_headers(wire_client):
"""
Expand Down
54 changes: 40 additions & 14 deletions azurelinuxagent/common/protocol/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import json
import os
import random
import shutil
import time
import zipfile

from collections import defaultdict
from datetime import datetime, timedelta
Expand Down Expand Up @@ -604,35 +606,39 @@ def hgap_download(uri):

return self._download_with_fallback_channel(download_type, uris, direct_download=direct_download, hgap_download=hgap_download)

def download_extension(self, uris, destination, use_verify_header, on_downloaded=lambda: True):
def download_zip_package(self, package_type, uris, target_file, target_directory, use_verify_header):
"""
Walks the given list of 'uris' issuing HTTP GET requests and saves the content of the first successful request to 'destination'.
Downloads the ZIP package specified in 'uris' (which is a list of alternate locations for the ZIP), saving it to 'target_file' and then expanding
its contents to 'target_directory'. Deletes the target file after it has been expanded.
When the download is successful, this method invokes the 'on_downloaded' callback function, which can be used to process the results of the download.
on_downloaded() should return True on success and False on failure (it should not raise any exceptions); ff the return value is False, the download
is considered a failure and the next URI is tried.
The 'package_type' is only used in log messages and has no other semantics. It should specify the contents of the ZIP, e.g. "extension package"
or "agent package"
The 'use_verify_header' parameter indicates whether the verify header should be added when using the extensionArtifact API of the HostGAPlugin.
"""
host_ga_plugin = self.get_host_plugin()

direct_download = lambda uri: self.stream(uri, destination, headers=None, use_proxy=True)
direct_download = lambda uri: self.stream(uri, target_file, headers=None, use_proxy=True)

def hgap_download(uri):
request_uri, request_headers = host_ga_plugin.get_artifact_request(uri, use_verify_header=use_verify_header, artifact_manifest_url=host_ga_plugin.manifest_uri)
return self.stream(request_uri, destination, headers=request_headers, use_proxy=False)
return self.stream(request_uri, target_file, headers=request_headers, use_proxy=False)

on_downloaded = lambda: WireClient._try_expand_zip_package(package_type, target_file, target_directory)

self._download_with_fallback_channel("extension package", uris, direct_download=direct_download, hgap_download=hgap_download, on_downloaded=on_downloaded)
self._download_with_fallback_channel(package_type, uris, direct_download=direct_download, hgap_download=hgap_download, on_downloaded=on_downloaded)

def _download_with_fallback_channel(self, download_type, uris, direct_download, hgap_download, on_downloaded=lambda: True):
def _download_with_fallback_channel(self, download_type, uris, direct_download, hgap_download, on_downloaded=None):
"""
Walks the given list of 'uris' issuing HTTP GET requests, attempting to download the content of each URI. The download is done using both the default and
the fallback channels, until one of them succeeds. The 'direct_download' and 'hgap_download' functions define the logic to do direct calls to the URI or
to use the HostGAPlugin as a proxy for the download. Initially the default channel is the direct download and the fallback channel is the HostGAPlugin,
but the default can be depending on the success/failure of each channel (see _download_using_appropriate_channel() for the logic to do this).
The 'download_type' is added to any log messages produced by this method; it should describe the type of content of the given URIs
(e.g. "manifest", "extension package", etc).
(e.g. "manifest", "extension package, "agent package", etc).
When the download is successful download_extension() invokes the 'on_downloaded' function, which can be used to process the results of the download. This
When the download is successful, _download_with_fallback_channel invokes the 'on_downloaded' function, which can be used to process the results of the download. This
function should return True on success, and False on failure (it should not raise any exceptions). If the return value is False, the download is considered
a failure and the next URI is tried.
Expand All @@ -641,7 +647,7 @@ def _download_with_fallback_channel(self, download_type, uris, direct_download,
This method enforces a timeout (_DOWNLOAD_TIMEOUT) on the download and raises an exception if the limit is exceeded.
"""
logger.verbose("Downloading {0}", download_type)
logger.info("Downloading {0}", download_type)
start_time = datetime.now()

uris_shuffled = uris
Expand All @@ -658,14 +664,34 @@ def _download_with_fallback_channel(self, download_type, uris, direct_download,
# Disable W0640: OK to use uri in a lambda within the loop's body
response = self._download_using_appropriate_channel(lambda: direct_download(uri), lambda: hgap_download(uri)) # pylint: disable=W0640

if on_downloaded():
return uri, response
if on_downloaded is not None:
on_downloaded()

return uri, response
except Exception as exception:
most_recent_error = exception

raise ExtensionDownloadError("Failed to download {0} from all URIs. Last error: {1}".format(download_type, ustr(most_recent_error)), code=ExtensionErrorCodes.PluginManifestDownloadError)

@staticmethod
def _try_expand_zip_package(package_type, target_file, target_directory):
logger.info("Unzipping {0}: {1}", package_type, target_file)
try:
zipfile.ZipFile(target_file).extractall(target_directory)
except Exception as exception:
logger.error("Error while unzipping {0}: {1}", package_type, ustr(exception))
if os.path.exists(target_directory):
try:
shutil.rmtree(target_directory)
except Exception as exception:
logger.warn("Cannot delete {0}: {1}", target_directory, ustr(exception))
raise
finally:
try:
os.remove(target_file)
except Exception as exception:
logger.warn("Cannot delete {0}: {1}", target_file, ustr(exception))

def stream(self, uri, destination, headers=None, use_proxy=None):
"""
Downloads the content of the given 'uri' and saves it to the 'destination' file.
Expand Down
15 changes: 9 additions & 6 deletions azurelinuxagent/ga/exthandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
ExtensionOperationError, ExtensionUpdateError, ProtocolError, ProtocolNotFoundError, ExtensionsGoalStateError, \
GoalStateAggregateStatusCodes, MultiConfigExtensionEnableError
from azurelinuxagent.common.future import ustr, is_file_not_found_error
from azurelinuxagent.common.protocol.extensions_goal_state import GoalStateSource
from azurelinuxagent.common.protocol.restapi import ExtensionStatus, ExtensionSubStatus, Extension, ExtHandlerStatus, \
VMStatus, GoalStateAggregateStatus, ExtensionState, ExtensionRequestedState, ExtensionSettings
from azurelinuxagent.common.utils import textutil
Expand Down Expand Up @@ -1252,21 +1253,23 @@ def download(self):
if self.pkg is None or self.pkg.uris is None or len(self.pkg.uris) == 0:
raise ExtensionDownloadError("No package uri found")

destination = os.path.join(conf.get_lib_dir(), self.get_extension_package_zipfile_name())
package_file = os.path.join(conf.get_lib_dir(), self.get_extension_package_zipfile_name())

package_exists = False
if os.path.exists(destination):
self.logger.info("Using existing extension package: {0}", destination)
if self._unzip_extension_package(destination, self.get_base_dir()):
if os.path.exists(package_file):
self.logger.info("Using existing extension package: {0}", package_file)
if self._unzip_extension_package(package_file, self.get_base_dir()):
package_exists = True
else:
self.logger.info("The existing extension package is invalid, will ignore it.")

if not package_exists:
self.protocol.get_goal_state().download_extension(self.pkg.uris, destination, on_downloaded=lambda: self._unzip_extension_package(destination, self.get_base_dir()))
is_fast_track_goal_state = self.protocol.get_goal_state().extensions_goal_state.source == GoalStateSource.FastTrack
self.protocol.client.download_zip_package("extension package", self.pkg.uris, package_file, self.get_base_dir(), use_verify_header=is_fast_track_goal_state)
self.report_event(message="Download succeeded", duration=elapsed_milliseconds(begin_utc))

self.pkg_file = destination
self.pkg_file = package_file


def ensure_consistent_data_for_mc(self):
# If CRP expects Handler to support MC, ensure the HandlerManifest also reflects that.
Expand Down
Loading

0 comments on commit 527443c

Please sign in to comment.