diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index f13c29271d730..3b1323ccd24f5 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -704,7 +704,7 @@ def _expand_jars(jar): return glob.glob(jar) elif isinstance(jar, str) and (jar.startswith('http://') or jar.startswith('https://')): - return [jar] + return [subprocess_server.JavaJarServer.local_jar(jar)] else: # If the input JAR is not a local glob, nor an http/https URL, then # we assume that it's a gradle-style Java artifact in Maven Central, @@ -716,8 +716,9 @@ def _expand_jars(jar): # a JAR path, we still choose to include it in the path. logging.warning('Unable to parse %s into group:artifact:version.', jar) return [jar] - path = subprocess_server.JavaJarServer.path_to_maven_jar( - artifact_id, group_id, version) + path = subprocess_server.JavaJarServer.local_jar( + subprocess_server.JavaJarServer.path_to_maven_jar( + artifact_id, group_id, version)) return [path] def _default_args(self): diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index fdb4dd00e66fa..36a45a62e2e4a 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -26,6 +26,8 @@ import typing import unittest +import mock + import apache_beam as beam from apache_beam import Pipeline from apache_beam.coders import RowCoder @@ -47,6 +49,7 @@ from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.utils import proto_utils +from apache_beam.utils.subprocess_server import JavaJarServer # Protect against environments where apitools library is not available. # pylint: disable=wrong-import-order, wrong-import-position @@ -570,27 +573,68 @@ def test_classpath(self): finally: os.chdir(oldwd) - def test_maven_central_classpath(self): + @mock.patch.object(JavaJarServer, 'local_jar') + def test_classpath_with_url(self, local_jar): + def _side_effect_fn(path): + return path[path.rindex('/') + 1:] + + local_jar.side_effect = _side_effect_fn + with tempfile.TemporaryDirectory() as temp_dir: try: # Avoid having to prefix everything in our test strings. oldwd = os.getcwd() os.chdir(temp_dir) - # Touch some files for globing. - with open('a1.jar', 'w') as _: - pass service = JavaJarExpansionService( - 'main.jar', - classpath=['a*.jar', 'b.jar', 'org.postgresql:postgresql:42.2.16']) + 'main.jar', classpath=['https://dummy_path/dummyjar.jar']) + + self.assertEqual( + service._default_args(), + ['{{PORT}}', '--filesToStage=main.jar,dummyjar.jar']) + finally: + os.chdir(oldwd) + + @mock.patch.object(JavaJarServer, 'local_jar') + def test_classpath_with_gradle_artifact(self, local_jar): + def _side_effect_fn(path): + return path[path.rindex('/') + 1:] + + local_jar.side_effect = _side_effect_fn + + with tempfile.TemporaryDirectory() as temp_dir: + try: + # Avoid having to prefix everything in our test strings. + oldwd = os.getcwd() + os.chdir(temp_dir) + + service = JavaJarExpansionService( + 'main.jar', classpath=['dummy_group:dummy_artifact:dummy_version']) + self.assertEqual( service._default_args(), [ '{{PORT}}', - '--filesToStage=main.jar,a1.jar,b.jar,' - 'https://repo.maven.apache.org/maven2/org/' - 'postgresql/postgresql/42.2.16/postgresql-42.2.16.jar' + '--filesToStage=main.jar,dummy_artifact-dummy_version.jar' ]) + finally: + os.chdir(oldwd) + + def test_classpath_with_glob(self): + with tempfile.TemporaryDirectory() as temp_dir: + try: + # Avoid having to prefix everything in our test strings. + oldwd = os.getcwd() + os.chdir(temp_dir) + # Touch some files for globing. + with open('a1.jar', 'w') as _: + pass + + service = JavaJarExpansionService( + 'main.jar', classpath=['a*.jar', 'b.jar']) + self.assertEqual( + service._default_args(), + ['{{PORT}}', '--filesToStage=main.jar,a1.jar,b.jar']) finally: os.chdir(oldwd)