From 729e6d144655bd26e6453dcc01a7e6f0d5c8f50e Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Wed, 19 Jul 2023 15:18:53 +0530 Subject: [PATCH] userdata: fix append scenarios (#7741) Fixes case of appending userdata when both template and vm data are either shellscript or cloudconfig Fixes error when appending gzip userdata Fixes case when userdata manual text from VM is not getting decoded-encoded correctly. Fixes case of appending multipart data when both template and vm data contain same format types. Refactor - moved validateUserData method to UserDataManager class Refactor userdata test to check resultant multipart userdata thoroughly Signed-off-by: Abhishek Kumar --- .../cloudstack/userdata/UserDataManager.java | 5 +- .../configuration/ConfigurationManager.java | 5 + .../userdata/CloudInitUserDataProvider.java | 134 ++++++++++++---- .../CloudInitUserDataProviderTest.java | 145 +++++++++++++----- engine/userdata/pom.xml | 6 + .../userdata/UserDataManagerImpl.java | 76 +++++++-- .../cloudstack/userdata/UserDataProvider.java | 4 +- .../userdata/UserDataManagerImplTest.java | 59 +++++++ .../ConfigurationManagerImpl.java | 3 - .../cloud/server/ManagementServerImpl.java | 73 ++------- .../main/java/com/cloud/vm/UserVmManager.java | 2 - .../java/com/cloud/vm/UserVmManagerImpl.java | 62 +------- .../server/ManagementServerImplTest.java | 70 +++++---- .../com/cloud/vm/UserVmManagerImplTest.java | 10 +- .../java/com/cloud/vm/UserVmManagerTest.java | 25 --- .../smoke/test_register_userdata.py | 39 +++-- 16 files changed, 443 insertions(+), 275 deletions(-) create mode 100644 engine/userdata/src/test/java/org/apache/cloudstack/userdata/UserDataManagerImplTest.java diff --git a/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java b/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java index 2fc3acd45d15..4dfcd0a7de1b 100644 --- a/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java +++ b/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java @@ -16,9 +16,12 @@ // under the License. package org.apache.cloudstack.userdata; -import com.cloud.utils.component.Manager; +import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.framework.config.Configurable; +import com.cloud.utils.component.Manager; + public interface UserDataManager extends Manager, Configurable { String concatenateUserData(String userdata1, String userdata2, String userdataProvider); + String validateUserData(String userData, BaseCmd.HTTPMethod httpmethod); } diff --git a/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java b/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java index c5caa312b584..5343fb632b54 100644 --- a/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java +++ b/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java @@ -20,6 +20,7 @@ import java.util.Map; import java.util.Set; +import org.apache.cloudstack.framework.config.ConfigKey; import org.apache.cloudstack.framework.config.impl.ConfigurationSubGroupVO; import com.cloud.dc.ClusterVO; @@ -59,6 +60,10 @@ public interface ConfigurationManager { public static final String MESSAGE_CREATE_VLAN_IP_RANGE_EVENT = "Message.CreateVlanIpRange.Event"; public static final String MESSAGE_DELETE_VLAN_IP_RANGE_EVENT = "Message.DeleteVlanIpRange.Event"; + static final String VM_USERDATA_MAX_LENGTH_STRING = "vm.userdata.max.length"; + static final ConfigKey VM_USERDATA_MAX_LENGTH = new ConfigKey<>("Advanced", Integer.class, VM_USERDATA_MAX_LENGTH_STRING, "32768", + "Max length of vm userdata after base64 decoding. Default is 32768 and maximum is 1048576", true); + /** * @param offering * @return diff --git a/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java b/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java index c61f37a18966..65996f181a9c 100644 --- a/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java +++ b/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java @@ -19,7 +19,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import java.io.InputStream; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -35,12 +35,14 @@ import javax.mail.internet.MimeMessage; import javax.mail.internet.MimeMultipart; +import org.apache.commons.codec.binary.Base64; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.log4j.Logger; import com.cloud.utils.component.AdapterBase; import com.cloud.utils.exception.CloudRuntimeException; +import com.sun.mail.util.BASE64DecoderStream; public class CloudInitUserDataProvider extends AdapterBase implements UserDataProvider { @@ -69,11 +71,11 @@ public String getName() { return "cloud-init"; } - protected boolean isGZipped(String userdata) { - if (StringUtils.isEmpty(userdata)) { + protected boolean isGZipped(String encodedUserdata) { + if (StringUtils.isEmpty(encodedUserdata)) { return false; } - byte[] data = userdata.getBytes(StandardCharsets.ISO_8859_1); + byte[] data = Base64.decodeBase64(encodedUserdata); if (data.length < 2) { return false; } @@ -82,9 +84,6 @@ protected boolean isGZipped(String userdata) { } protected String extractUserDataHeader(String userdata) { - if (isGZipped(userdata)) { - throw new CloudRuntimeException("Gzipped user data can not be used together with other user data formats"); - } List lines = Arrays.stream(userdata.split("\n")) .filter(x -> (x.startsWith("#") && !x.startsWith("##")) || (x.startsWith("Content-Type:"))) .collect(Collectors.toList()); @@ -131,7 +130,7 @@ protected FormatType getUserDataFormatType(String userdata) { private String getContentType(String userData, FormatType formatType) throws MessagingException { if (formatType == FormatType.MIME) { - MimeMessage msg = new MimeMessage(session, new ByteArrayInputStream(userData.getBytes())); + NoIdMimeMessage msg = new NoIdMimeMessage(session, new ByteArrayInputStream(userData.getBytes())); return msg.getContentType(); } if (!formatContentTypeMap.containsKey(formatType)) { @@ -141,15 +140,35 @@ private String getContentType(String userData, FormatType formatType) throws Mes return formatContentTypeMap.get(formatType); } - protected MimeBodyPart generateBodyPartMIMEMessage(String userData, FormatType formatType) throws MessagingException { + protected String getBodyPartContentAsString(BodyPart bodyPart) throws MessagingException, IOException { + Object content = bodyPart.getContent(); + if (content instanceof BASE64DecoderStream) { + return new String(((BASE64DecoderStream)bodyPart.getContent()).readAllBytes()); + } else if (content instanceof ByteArrayInputStream) { + return new String(((ByteArrayInputStream)bodyPart.getContent()).readAllBytes()); + } else if (content instanceof String) { + return (String)bodyPart.getContent(); + } + throw new CloudRuntimeException(String.format("Failed to get content for multipart data with content type: %s", getBodyPartContentType(bodyPart))); + } + + private String getBodyPartContentType(BodyPart bodyPart) throws MessagingException { + String contentType = StringUtils.defaultString(bodyPart.getDataHandler().getContentType(), bodyPart.getContentType()); + return contentType.contains(";") ? contentType.substring(0, contentType.indexOf(';')) : contentType; + } + + protected MimeBodyPart generateBodyPartMimeMessage(String userData, String contentType) throws MessagingException { MimeBodyPart bodyPart = new MimeBodyPart(); - String contentType = getContentType(userData, formatType); bodyPart.setContent(userData, contentType); bodyPart.addHeader("Content-Transfer-Encoding", "base64"); return bodyPart; } - private Multipart getMessageContent(MimeMessage message) { + protected MimeBodyPart generateBodyPartMimeMessage(String userData, FormatType formatType) throws MessagingException { + return generateBodyPartMimeMessage(userData, getContentType(userData, formatType)); + } + + private Multipart getMessageContent(NoIdMimeMessage message) { Multipart messageContent; try { messageContent = (MimeMultipart) message.getContent(); @@ -159,40 +178,83 @@ private Multipart getMessageContent(MimeMessage message) { return messageContent; } - private void addBodyPartsToMessageContentFromUserDataContent(Multipart messageContent, - MimeMessage msgFromUserdata) throws MessagingException, IOException { - Multipart msgFromUserdataParts = (MimeMultipart) msgFromUserdata.getContent(); - int count = msgFromUserdataParts.getCount(); - int i = 0; - while (i < count) { - BodyPart bodyPart = msgFromUserdataParts.getBodyPart(0); - messageContent.addBodyPart(bodyPart); - i++; + private void addBodyPartToMultipart(Multipart existingMultipart, MimeBodyPart bodyPart) throws MessagingException, IOException { + boolean added = false; + final int existingCount = existingMultipart.getCount(); + for (int j = 0; j < existingCount; ++j) { + MimeBodyPart existingBodyPart = (MimeBodyPart)existingMultipart.getBodyPart(j); + String existingContentType = getBodyPartContentType(existingBodyPart); + String newContentType = getBodyPartContentType(bodyPart); + if (existingContentType.equals(newContentType)) { + String existingContent = getBodyPartContentAsString(existingBodyPart); + String newContent = getBodyPartContentAsString(bodyPart); + // generating a combined content MimeBodyPart to replace + MimeBodyPart combinedBodyPart = generateBodyPartMimeMessage( + simpleAppendSameFormatTypeUserData(existingContent, newContent), existingContentType); + existingMultipart.removeBodyPart(j); + existingMultipart.addBodyPart(combinedBodyPart, j); + added = true; + break; + } + } + if (!added) { + existingMultipart.addBodyPart(bodyPart); + } + } + + private void addBodyPartsToMessageContentFromUserDataContent(Multipart existingMultipart, + NoIdMimeMessage msgFromUserdata) throws MessagingException, IOException { + MimeMultipart newMultipart = (MimeMultipart)msgFromUserdata.getContent(); + final int existingCount = existingMultipart.getCount(); + final int newCount = newMultipart.getCount(); + for (int i = 0; i < newCount; ++i) { + BodyPart bodyPart = newMultipart.getBodyPart(i); + if (existingCount == 0) { + existingMultipart.addBodyPart(bodyPart); + continue; + } + addBodyPartToMultipart(existingMultipart, (MimeBodyPart)bodyPart); } } - private MimeMessage createMultipartMessageAddingUserdata(String userData, FormatType formatType, - MimeMessage message) throws MessagingException, IOException { - MimeMessage newMessage = new MimeMessage(session); + private NoIdMimeMessage createMultipartMessageAddingUserdata(String userData, FormatType formatType, + NoIdMimeMessage message) throws MessagingException, IOException { + NoIdMimeMessage newMessage = new NoIdMimeMessage(session); Multipart messageContent = getMessageContent(message); if (formatType == FormatType.MIME) { - MimeMessage msgFromUserdata = new MimeMessage(session, new ByteArrayInputStream(userData.getBytes())); + NoIdMimeMessage msgFromUserdata = new NoIdMimeMessage(session, new ByteArrayInputStream(userData.getBytes())); addBodyPartsToMessageContentFromUserDataContent(messageContent, msgFromUserdata); } else { - MimeBodyPart part = generateBodyPartMIMEMessage(userData, formatType); - messageContent.addBodyPart(part); + MimeBodyPart part = generateBodyPartMimeMessage(userData, formatType); + addBodyPartToMultipart(messageContent, part); } newMessage.setContent(messageContent); return newMessage; } + private String simpleAppendSameFormatTypeUserData(String userData1, String userData2) { + return String.format("%s\n\n%s", userData1, userData2.substring(userData2.indexOf('\n')+1)); + } + + private void checkGzipAppend(String encodedUserData1, String encodedUserData2) { + if (isGZipped(encodedUserData1) || isGZipped(encodedUserData2)) { + throw new CloudRuntimeException("Gzipped user data can not be used together with other user data formats"); + } + } + @Override - public String appendUserData(String userData1, String userData2) { + public String appendUserData(String encodedUserData1, String encodedUserData2) { try { + checkGzipAppend(encodedUserData1, encodedUserData2); + String userData1 = new String(Base64.decodeBase64(encodedUserData1)); + String userData2 = new String(Base64.decodeBase64(encodedUserData2)); FormatType formatType1 = getUserDataFormatType(userData1); FormatType formatType2 = getUserDataFormatType(userData2); - MimeMessage message = new MimeMessage(session); + if (formatType1.equals(formatType2) && List.of(FormatType.CLOUD_CONFIG, FormatType.BASH_SCRIPT).contains(formatType1)) { + return simpleAppendSameFormatTypeUserData(userData1, userData2); + } + NoIdMimeMessage message = new NoIdMimeMessage(session); message = createMultipartMessageAddingUserdata(userData1, formatType1, message); message = createMultipartMessageAddingUserdata(userData2, formatType2, message); ByteArrayOutputStream output = new ByteArrayOutputStream(); @@ -205,4 +267,20 @@ public String appendUserData(String userData1, String userData2) { throw new CloudRuntimeException(msg, e); } } + + /* This is a wrapper class just to remove Message-ID header from the resultant + multipart data which may contain server details. + */ + private class NoIdMimeMessage extends MimeMessage { + NoIdMimeMessage (Session session) { + super(session); + } + NoIdMimeMessage (Session session, InputStream is) throws MessagingException { + super(session, is); + } + @Override + protected void updateMessageID() throws MessagingException { + removeHeader("Message-ID"); + } + } } diff --git a/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java b/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java index b91438c5a360..4ca9fb7ebd67 100644 --- a/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java +++ b/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java @@ -16,11 +16,20 @@ // under the License. package org.apache.cloudstack.userdata; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Properties; import java.util.zip.GZIPOutputStream; +import javax.mail.BodyPart; +import javax.mail.MessagingException; +import javax.mail.Session; +import javax.mail.internet.MimeMessage; +import javax.mail.internet.MimeMultipart; + +import org.apache.commons.codec.binary.Base64; import org.junit.Assert; import org.junit.Test; @@ -34,6 +43,33 @@ public class CloudInitUserDataProviderTest { "runcmd:\n" + " - echo 'TestVariable {{ ds.meta_data.variable1 }}' >> /tmp/variable\n" + " - echo 'Hostname {{ ds.meta_data.public_hostname }}' > /tmp/hostname"; + private final static String CLOUD_CONFIG_USERDATA1 = "#cloud-config\n" + + "password: atomic\n" + + "chpasswd: { expire: False }\n" + + "ssh_pwauth: True"; + private final static String SHELL_SCRIPT_USERDATA = "#!/bin/bash\n" + + "date > /provisioned"; + private final static String SHELL_SCRIPT_USERDATA1 = "#!/bin/bash\n" + + "mkdir /tmp/test"; + private final static String SINGLE_BODYPART_CLOUDCONFIG_MULTIPART_USERDATA = + "Content-Type: multipart/mixed; boundary=\"//\"\n" + + "MIME-Version: 1.0\n" + + "\n" + + "--//\n" + + "Content-Type: text/cloud-config; charset=\"us-ascii\"\n" + + "MIME-Version: 1.0\n" + + "Content-Transfer-Encoding: 7bit\n" + + "Content-Disposition: attachment; filename=\"cloud-config.txt\"\n" + + "\n" + + "#cloud-config\n" + + "\n" + + "# Upgrade the instance on first boot\n" + + "# (ie run apt-get upgrade)\n" + + "#\n" + + "# Default: false\n" + + "# Aliases: apt_upgrade\n" + + "package_upgrade: true"; + private static final Session session = Session.getDefaultInstance(new Properties()); @Test public void testGetUserDataFormatType() { @@ -54,51 +90,81 @@ public void testGetUserDataFormatTypeInvalidType() { provider.getUserDataFormatType(userdata); } + private MimeMultipart getCheckedMultipartFromMultipartData(String multipartUserData, int count) { + MimeMultipart multipart = null; + Assert.assertTrue(multipartUserData.contains("Content-Type: multipart")); + try { + MimeMessage msgFromUserdata = new MimeMessage(session, + new ByteArrayInputStream(multipartUserData.getBytes())); + multipart = (MimeMultipart)msgFromUserdata.getContent(); + Assert.assertEquals(count, multipart.getCount()); + } catch (MessagingException | IOException e) { + Assert.fail(String.format("Failed with exception, %s", e.getMessage())); + } + return multipart; + } + @Test public void testAppendUserData() { - String templateData = "#cloud-config\n" + - "password: atomic\n" + - "chpasswd: { expire: False }\n" + - "ssh_pwauth: True"; - String vmData = "#!/bin/bash\n" + - "date > /provisioned"; - String multipartUserData = provider.appendUserData(templateData, vmData); - Assert.assertTrue(multipartUserData.contains("Content-Type: multipart")); + String multipartUserData = provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA1.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); + getCheckedMultipartFromMultipartData(multipartUserData, 2); + } + + @Test + public void testAppendSameShellScriptTypeUserData() { + String result = SHELL_SCRIPT_USERDATA + "\n\n" + + SHELL_SCRIPT_USERDATA1.replace("#!/bin/bash\n", ""); + String appendUserData = provider.appendUserData(Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA1.getBytes())); + Assert.assertEquals(result, appendUserData); + } + + @Test + public void testAppendSameCloudConfigTypeUserData() { + String result = CLOUD_CONFIG_USERDATA + "\n\n" + + CLOUD_CONFIG_USERDATA1.replace("#cloud-config\n", ""); + String appendUserData = provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA.getBytes()), + Base64.encodeBase64String(CLOUD_CONFIG_USERDATA1.getBytes())); + Assert.assertEquals(result, appendUserData); } @Test public void testAppendUserDataMIMETemplateData() { - String templateData = "Content-Type: multipart/mixed; boundary=\"//\"\n" + - "MIME-Version: 1.0\n" + - "\n" + - "--//\n" + - "Content-Type: text/cloud-config; charset=\"us-ascii\"\n" + - "MIME-Version: 1.0\n" + - "Content-Transfer-Encoding: 7bit\n" + - "Content-Disposition: attachment; filename=\"cloud-config.txt\"\n" + - "\n" + - "#cloud-config\n" + - "\n" + - "# Upgrade the instance on first boot\n" + - "# (ie run apt-get upgrade)\n" + - "#\n" + - "# Default: false\n" + - "# Aliases: apt_upgrade\n" + - "package_upgrade: true"; - String vmData = "#!/bin/bash\n" + - "date > /provisioned"; - String multipartUserData = provider.appendUserData(templateData, vmData); - Assert.assertTrue(multipartUserData.contains("Content-Type: multipart")); + String multipartUserData = provider.appendUserData( + Base64.encodeBase64String(SINGLE_BODYPART_CLOUDCONFIG_MULTIPART_USERDATA.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); + getCheckedMultipartFromMultipartData(multipartUserData, 2); + } + + @Test + public void testAppendUserDataExistingMultipartWithSameType() { + String templateData = provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA1.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); + String multipartUserData = provider.appendUserData(Base64.encodeBase64String(templateData.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA1.getBytes())); + String resultantShellScript = SHELL_SCRIPT_USERDATA + "\n\n" + + SHELL_SCRIPT_USERDATA1.replace("#!/bin/bash\n", ""); + MimeMultipart mimeMultipart = getCheckedMultipartFromMultipartData(multipartUserData, 2); + try { + for (int i = 0; i < mimeMultipart.getCount(); ++i) { + BodyPart bodyPart = mimeMultipart.getBodyPart(i); + if (bodyPart.getContentType().startsWith("text/x-shellscript")) { + Assert.assertEquals(resultantShellScript, provider.getBodyPartContentAsString(bodyPart)); + } else if (bodyPart.getContentType().startsWith("text/cloud-config")) { + Assert.assertEquals(CLOUD_CONFIG_USERDATA1, provider.getBodyPartContentAsString(bodyPart)); + } + } + } catch (MessagingException | IOException | CloudRuntimeException e) { + Assert.fail(String.format("Failed with exception, %s", e.getMessage())); + } } @Test(expected = CloudRuntimeException.class) public void testAppendUserDataInvalidUserData() { - String templateData = "password: atomic\n" + - "chpasswd: { expire: False }\n" + - "ssh_pwauth: True"; - String vmData = "#!/bin/bash\n" + - "date > /provisioned"; - provider.appendUserData(templateData, vmData); + String templateData = CLOUD_CONFIG_USERDATA1.replace("#cloud-config\n", ""); + provider.appendUserData(Base64.encodeBase64String(templateData.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); } @Test @@ -106,7 +172,7 @@ public void testIsGzippedUserDataWithCloudConfigData() { Assert.assertFalse(provider.isGZipped(CLOUD_CONFIG_USERDATA)); } - private String createGzipDataAsString() throws IOException { + private String createBase64EncodedGzipDataAsString() throws IOException { byte[] input = CLOUD_CONFIG_USERDATA.getBytes(StandardCharsets.ISO_8859_1); ByteArrayOutputStream arrayOutputStream = new ByteArrayOutputStream(); @@ -114,13 +180,13 @@ private String createGzipDataAsString() throws IOException { outputStream.write(input,0, input.length); outputStream.close(); - return arrayOutputStream.toString(StandardCharsets.ISO_8859_1); + return Base64.encodeBase64String(arrayOutputStream.toByteArray()); } @Test public void testIsGzippedUserDataWithValidGzipData() { try { - String gzipped = createGzipDataAsString(); + String gzipped = createBase64EncodedGzipDataAsString(); Assert.assertTrue(provider.isGZipped(gzipped)); } catch (IOException e) { Assert.fail(e.getMessage()); @@ -130,7 +196,8 @@ public void testIsGzippedUserDataWithValidGzipData() { @Test(expected = CloudRuntimeException.class) public void testAppendUserDataWithGzippedData() { try { - provider.appendUserData(CLOUD_CONFIG_USERDATA, createGzipDataAsString()); + provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA.getBytes()), + createBase64EncodedGzipDataAsString()); Assert.fail("Gzipped data shouldn't be appended with other data"); } catch (IOException e) { Assert.fail("Exception encountered: " + e.getMessage()); diff --git a/engine/userdata/pom.xml b/engine/userdata/pom.xml index 2e00ebd97867..75475b2af183 100644 --- a/engine/userdata/pom.xml +++ b/engine/userdata/pom.xml @@ -43,5 +43,11 @@ activation 1.1.1 + + org.apache.cloudstack + cloud-engine-components-api + 4.19.0.0-SNAPSHOT + compile + diff --git a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java index b2ee9dfd6079..91f24fe70458 100644 --- a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java +++ b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java @@ -16,17 +16,29 @@ // under the License. package org.apache.cloudstack.userdata; -import com.cloud.utils.component.ManagerBase; -import com.cloud.utils.exception.CloudRuntimeException; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.framework.config.ConfigKey; import org.apache.commons.codec.binary.Base64; import org.apache.commons.lang3.StringUtils; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import com.cloud.configuration.ConfigurationManager; +import com.cloud.exception.InvalidParameterValueException; +import com.cloud.utils.component.ManagerBase; +import com.cloud.utils.exception.CloudRuntimeException; public class UserDataManagerImpl extends ManagerBase implements UserDataManager { + + + private static final int MAX_USER_DATA_LENGTH_BYTES = 2048; + private static final int MAX_HTTP_GET_LENGTH = 2 * MAX_USER_DATA_LENGTH_BYTES; + private static final int NUM_OF_2K_BLOCKS = 512; + private static final int MAX_HTTP_POST_LENGTH = NUM_OF_2K_BLOCKS * MAX_USER_DATA_LENGTH_BYTES; private List userDataProviders; private static Map userDataProvidersMap = new HashMap<>(); @@ -71,12 +83,56 @@ protected UserDataProvider getUserdataProvider(String name) { @Override public String concatenateUserData(String userdata1, String userdata2, String userdataProvider) { - byte[] userdata1Bytes = Base64.decodeBase64(userdata1.getBytes()); - byte[] userdata2Bytes = Base64.decodeBase64(userdata2.getBytes()); - String userData1Str = new String(userdata1Bytes); - String userData2Str = new String(userdata2Bytes); UserDataProvider provider = getUserdataProvider(userdataProvider); - String appendUserData = provider.appendUserData(userData1Str, userData2Str); + String appendUserData = provider.appendUserData(userdata1, userdata2); return Base64.encodeBase64String(appendUserData.getBytes()); } + + @Override + public String validateUserData(String userData, BaseCmd.HTTPMethod httpmethod) { + byte[] decodedUserData = null; + if (userData != null) { + + if (userData.contains("%")) { + try { + userData = URLDecoder.decode(userData, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new InvalidParameterValueException("Url decoding of userdata failed."); + } + } + + if (!Base64.isBase64(userData)) { + throw new InvalidParameterValueException("User data is not base64 encoded"); + } + // If GET, use 4K. If POST, support up to 1M. + if (httpmethod.equals(BaseCmd.HTTPMethod.GET)) { + decodedUserData = validateAndDecodeByHTTPMethod(userData, MAX_HTTP_GET_LENGTH, BaseCmd.HTTPMethod.GET); + } else if (httpmethod.equals(BaseCmd.HTTPMethod.POST)) { + decodedUserData = validateAndDecodeByHTTPMethod(userData, MAX_HTTP_POST_LENGTH, BaseCmd.HTTPMethod.POST); + } + + if (decodedUserData == null || decodedUserData.length < 1) { + throw new InvalidParameterValueException("User data is too short"); + } + // Re-encode so that the '=' paddings are added if necessary since 'isBase64' does not require it, but python does on the VR. + return Base64.encodeBase64String(decodedUserData); + } + return null; + } + + private byte[] validateAndDecodeByHTTPMethod(String userData, int maxHTTPLength, BaseCmd.HTTPMethod httpMethod) { + byte[] decodedUserData = null; + + if (userData.length() >= maxHTTPLength) { + throw new InvalidParameterValueException(String.format("User data is too long for an http %s request", httpMethod.toString())); + } + if (userData.length() > ConfigurationManager.VM_USERDATA_MAX_LENGTH.value()) { + throw new InvalidParameterValueException("User data has exceeded configurable max length : " + ConfigurationManager.VM_USERDATA_MAX_LENGTH.value()); + } + decodedUserData = Base64.decodeBase64(userData.getBytes()); + if (decodedUserData.length > maxHTTPLength) { + throw new InvalidParameterValueException(String.format("User data is too long for http %s request", httpMethod.toString())); + } + return decodedUserData; + } } diff --git a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java index 9ac577b54ef7..4cdcd516695d 100644 --- a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java +++ b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java @@ -21,8 +21,8 @@ public interface UserDataProvider { /** * Append user data into a single user data. - * NOTE: userData1 and userData2 are decoded user data strings + * NOTE: userData1 and userData2 are Base64 encoded user data strings * @return a non-encrypted string containing both user data inputs */ - String appendUserData(String userData1, String userData2); + String appendUserData(String encodedUserData1, String encodedUserData2); } diff --git a/engine/userdata/src/test/java/org/apache/cloudstack/userdata/UserDataManagerImplTest.java b/engine/userdata/src/test/java/org/apache/cloudstack/userdata/UserDataManagerImplTest.java new file mode 100644 index 000000000000..67e7b38e37d0 --- /dev/null +++ b/engine/userdata/src/test/java/org/apache/cloudstack/userdata/UserDataManagerImplTest.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.cloudstack.userdata; + +import static org.junit.Assert.assertEquals; + +import java.nio.charset.StandardCharsets; + +import org.apache.cloudstack.api.BaseCmd; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InjectMocks; +import org.mockito.Spy; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class UserDataManagerImplTest { + + @Spy + @InjectMocks + private UserDataManagerImpl userDataManager; + + @Test + public void testValidateBase64WithoutPadding() { + // fo should be encoded in base64 either as Zm8 or Zm8= + String encodedUserdata = "Zm8"; + String encodedUserdataWithPadding = "Zm8="; + + // Verify that we accept both but return the padded version + assertEquals("validate return the value with padding", encodedUserdataWithPadding, userDataManager.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET)); + assertEquals("validate return the value with padding", encodedUserdataWithPadding, userDataManager.validateUserData(encodedUserdataWithPadding, BaseCmd.HTTPMethod.GET)); + } + + @Test + public void testValidateUrlEncodedBase64() { + // fo should be encoded in base64 either as Zm8 or Zm8= + String encodedUserdata = "Zm+8/w8="; + String urlEncodedUserdata = java.net.URLEncoder.encode(encodedUserdata, StandardCharsets.UTF_8); + + // Verify that we accept both but return the padded version + assertEquals("validate return the value with padding", encodedUserdata, userDataManager.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET)); + assertEquals("validate return the value with padding", encodedUserdata, userDataManager.validateUserData(urlEncodedUserdata, BaseCmd.HTTPMethod.GET)); + } + +} diff --git a/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java b/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java index c89b0e2e3b12..890fb1195e2e 100644 --- a/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java +++ b/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java @@ -460,7 +460,6 @@ public class ConfigurationManagerImpl extends ManagerBase implements Configurati protected Set configValuesForValidation; private Set weightBasedParametersForValidation; private Set overprovisioningFactorsForValidation; - public static final String VM_USERDATA_MAX_LENGTH_STRING = "vm.userdata.max.length"; public static final ConfigKey SystemVMUseLocalStorage = new ConfigKey(Boolean.class, "system.vm.use.local.storage", "Advanced", "false", "Indicates whether to use local storage pools or shared storage pools for system VMs.", false, ConfigKey.Scope.Zone, null); @@ -491,8 +490,6 @@ public class ConfigurationManagerImpl extends ManagerBase implements Configurati public static ConfigKey VM_SERVICE_OFFERING_MAX_RAM_SIZE = new ConfigKey("Advanced", Integer.class, "vm.serviceoffering.ram.size.max", "0", "Maximum RAM size in " + "MB for vm service offering. If 0 - no limitation", true); - public static final ConfigKey VM_USERDATA_MAX_LENGTH = new ConfigKey("Advanced", Integer.class, VM_USERDATA_MAX_LENGTH_STRING, "32768", - "Max length of vm userdata after base64 decoding. Default is 32768 and maximum is 1048576", true); public static final ConfigKey MIGRATE_VM_ACROSS_CLUSTERS = new ConfigKey(Boolean.class, "migrate.vm.across.clusters", "Advanced", "false", "Indicates whether the VM can be migrated to different cluster if no host is found in same cluster",true, ConfigKey.Scope.Zone, null); diff --git a/server/src/main/java/com/cloud/server/ManagementServerImpl.java b/server/src/main/java/com/cloud/server/ManagementServerImpl.java index 913063c8e639..c16dc4eb2f47 100644 --- a/server/src/main/java/com/cloud/server/ManagementServerImpl.java +++ b/server/src/main/java/com/cloud/server/ManagementServerImpl.java @@ -16,12 +16,7 @@ // under the License. package com.cloud.server; -import static com.cloud.configuration.ConfigurationManagerImpl.VM_USERDATA_MAX_LENGTH; -import static com.cloud.vm.UserVmManager.MAX_USER_DATA_LENGTH_BYTES; - -import java.io.UnsupportedEncodingException; import java.lang.reflect.Field; -import java.net.URLDecoder; import java.util.ArrayList; import java.util.Arrays; import java.util.Calendar; @@ -56,7 +51,6 @@ import org.apache.cloudstack.annotation.dao.AnnotationDao; import org.apache.cloudstack.api.ApiCommandResourceType; import org.apache.cloudstack.api.ApiConstants; -import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.api.command.admin.account.CreateAccountCmd; import org.apache.cloudstack.api.command.admin.account.DeleteAccountCmd; import org.apache.cloudstack.api.command.admin.account.DisableAccountCmd; @@ -611,6 +605,7 @@ import org.apache.cloudstack.storage.datastore.db.TemplateDataStoreVO; import org.apache.cloudstack.storage.datastore.db.VolumeDataStoreDao; import org.apache.cloudstack.storage.datastore.db.VolumeDataStoreVO; +import org.apache.cloudstack.userdata.UserDataManager; import org.apache.cloudstack.utils.CloudStackVersion; import org.apache.cloudstack.utils.identity.ManagementServerNode; import org.apache.commons.codec.binary.Base64; @@ -620,13 +615,13 @@ import com.cloud.agent.AgentManager; import com.cloud.agent.api.Answer; -import com.cloud.agent.api.Command; import com.cloud.agent.api.CheckGuestOsMappingAnswer; import com.cloud.agent.api.CheckGuestOsMappingCommand; -import com.cloud.agent.api.GetVncPortAnswer; -import com.cloud.agent.api.GetVncPortCommand; +import com.cloud.agent.api.Command; import com.cloud.agent.api.GetHypervisorGuestOsNamesAnswer; import com.cloud.agent.api.GetHypervisorGuestOsNamesCommand; +import com.cloud.agent.api.GetVncPortAnswer; +import com.cloud.agent.api.GetVncPortCommand; import com.cloud.agent.api.PatchSystemVmAnswer; import com.cloud.agent.api.PatchSystemVmCommand; import com.cloud.agent.api.proxy.AllowConsoleAccessCommand; @@ -696,7 +691,6 @@ import com.cloud.host.dao.HostDao; import com.cloud.host.dao.HostDetailsDao; import com.cloud.host.dao.HostTagsDao; -import com.cloud.hypervisor.Hypervisor; import com.cloud.hypervisor.Hypervisor.HypervisorType; import com.cloud.hypervisor.HypervisorCapabilities; import com.cloud.hypervisor.HypervisorCapabilitiesVO; @@ -779,6 +773,7 @@ import com.cloud.utils.crypt.DBEncryptionUtil; import com.cloud.utils.db.DB; import com.cloud.utils.db.Filter; +import com.cloud.utils.db.GenericSearchBuilder; import com.cloud.utils.db.GlobalLock; import com.cloud.utils.db.JoinBuilder; import com.cloud.utils.db.JoinBuilder.JoinType; @@ -788,7 +783,6 @@ import com.cloud.utils.db.TransactionCallbackNoReturn; import com.cloud.utils.db.TransactionStatus; import com.cloud.utils.db.UUIDManager; -import com.cloud.utils.db.GenericSearchBuilder; import com.cloud.utils.exception.CloudRuntimeException; import com.cloud.utils.fsm.StateMachine2; import com.cloud.utils.net.MacAddress; @@ -827,10 +821,6 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe static final ConfigKey humanReadableSizes = new ConfigKey("Advanced", Boolean.class, "display.human.readable.sizes", "true", "Enables outputting human readable byte sizes to logs and usage records.", false, ConfigKey.Scope.Global); public static final ConfigKey customCsIdentifier = new ConfigKey("Advanced", String.class, "custom.cs.identifier", UUID.randomUUID().toString().split("-")[0].substring(4), "Custom identifier for the cloudstack installation", true, ConfigKey.Scope.Global); private static final VirtualMachine.Type []systemVmTypes = { VirtualMachine.Type.SecondaryStorageVm, VirtualMachine.Type.ConsoleProxy}; - - private static final int MAX_HTTP_GET_LENGTH = 2 * MAX_USER_DATA_LENGTH_BYTES; - private static final int NUM_OF_2K_BLOCKS = 512; - private static final int MAX_HTTP_POST_LENGTH = NUM_OF_2K_BLOCKS * MAX_USER_DATA_LENGTH_BYTES; private static final List LIVE_MIGRATION_SUPPORTING_HYPERVISORS = List.of(HypervisorType.Hyperv, HypervisorType.KVM, HypervisorType.LXC, HypervisorType.Ovm, HypervisorType.Ovm3, HypervisorType.Simulator, HypervisorType.VMware, HypervisorType.XenServer); @@ -982,6 +972,8 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe protected VMTemplateDao templateDao; @Inject protected AnnotationDao annotationDao; + @Inject + UserDataManager userDataManager; private LockControllerListener _lockControllerListener; private final ScheduledExecutorService _eventExecutor = Executors.newScheduledThreadPool(1, new NamedThreadFactory("EventChecker")); @@ -999,7 +991,7 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe protected List _planners; - private final List supportedHypervisors = new ArrayList(); + private final List supportedHypervisors = new ArrayList(); public List getPlanners() { return _planners; @@ -4711,58 +4703,11 @@ public UserData registerUserData(final RegisterUserDataCmd cmd) { String userdata = cmd.getUserData(); final String params = cmd.getParams(); - userdata = validateUserData(userdata, cmd.getHttpMethod()); + userdata = userDataManager.validateUserData(userdata, cmd.getHttpMethod()); return createAndSaveUserData(name, userdata, params, owner); } - private String validateUserData(String userData, BaseCmd.HTTPMethod httpmethod) { - byte[] decodedUserData = null; - if (userData != null) { - - if (userData.contains("%")) { - try { - userData = URLDecoder.decode(userData, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new InvalidParameterValueException("Url decoding of userdata failed."); - } - } - - if (!Base64.isBase64(userData)) { - throw new InvalidParameterValueException("User data is not base64 encoded"); - } - // If GET, use 4K. If POST, support up to 1M. - if (httpmethod.equals(BaseCmd.HTTPMethod.GET)) { - decodedUserData = validateAndDecodeByHTTPmethod(userData, MAX_HTTP_GET_LENGTH, BaseCmd.HTTPMethod.GET); - } else if (httpmethod.equals(BaseCmd.HTTPMethod.POST)) { - decodedUserData = validateAndDecodeByHTTPmethod(userData, MAX_HTTP_POST_LENGTH, BaseCmd.HTTPMethod.POST); - } - - if (decodedUserData == null || decodedUserData.length < 1) { - throw new InvalidParameterValueException("User data is too short"); - } - // Re-encode so that the '=' paddings are added if necessary since 'isBase64' does not require it, but python does on the VR. - return Base64.encodeBase64String(decodedUserData); - } - return null; - } - - private byte[] validateAndDecodeByHTTPmethod(String userData, int maxHTTPlength, BaseCmd.HTTPMethod httpMethod) { - byte[] decodedUserData = null; - - if (userData.length() >= maxHTTPlength) { - throw new InvalidParameterValueException(String.format("User data is too long for an http %s request", httpMethod.toString())); - } - if (userData.length() > VM_USERDATA_MAX_LENGTH.value()) { - throw new InvalidParameterValueException("User data has exceeded configurable max length : " + VM_USERDATA_MAX_LENGTH.value()); - } - decodedUserData = Base64.decodeBase64(userData.getBytes()); - if (decodedUserData.length > maxHTTPlength) { - throw new InvalidParameterValueException(String.format("User data is too long for http %s request", httpMethod.toString())); - } - return decodedUserData; - } - /** * @param cmd * @param owner diff --git a/server/src/main/java/com/cloud/vm/UserVmManager.java b/server/src/main/java/com/cloud/vm/UserVmManager.java index 39f1e5d2d283..6dd9c27e5809 100644 --- a/server/src/main/java/com/cloud/vm/UserVmManager.java +++ b/server/src/main/java/com/cloud/vm/UserVmManager.java @@ -58,8 +58,6 @@ public interface UserVmManager extends UserVmService { "Destroys the VM's root volume when the VM is destroyed.", true, ConfigKey.Scope.Domain); - static final int MAX_USER_DATA_LENGTH_BYTES = 2048; - public static final String CKS_NODE = "cksnode"; /** diff --git a/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java b/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java index d223fddd8752..159c230b8764 100644 --- a/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java +++ b/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java @@ -16,7 +16,6 @@ // under the License. package com.cloud.vm; -import static com.cloud.configuration.ConfigurationManagerImpl.VM_USERDATA_MAX_LENGTH; import static com.cloud.utils.NumbersUtil.toHumanReadableSize; import java.io.IOException; @@ -127,7 +126,6 @@ import org.apache.cloudstack.utils.bytescale.ByteScaleUtils; import org.apache.cloudstack.utils.security.ParserUtils; import org.apache.cloudstack.vm.schedule.VMScheduleManager; -import org.apache.commons.codec.binary.Base64; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.MapUtils; import org.apache.commons.lang.math.NumberUtils; @@ -603,10 +601,6 @@ public class UserVmManagerImpl extends ManagerBase implements UserVmManager, Vir protected static long ROOT_DEVICE_ID = 0; - private static final int MAX_HTTP_GET_LENGTH = 2 * MAX_USER_DATA_LENGTH_BYTES; - private static final int NUM_OF_2K_BLOCKS = 512; - private static final int MAX_HTTP_POST_LENGTH = NUM_OF_2K_BLOCKS * MAX_USER_DATA_LENGTH_BYTES; - @Inject private OrchestrationService _orchSrvc; @@ -947,7 +941,7 @@ public UserVm resetVMUserData(ResetVMUserDataCmd cmd) throws ResourceUnavailable userDataDetails = cmd.getUserdataDetails().toString(); } userData = finalizeUserData(userData, userDataId, template); - userData = validateUserData(userData, cmd.getHttpMethod()); + userData = userDataManager.validateUserData(userData, cmd.getHttpMethod()); userVm.setUserDataId(userDataId); userVm.setUserData(userData); @@ -2950,7 +2944,7 @@ public UserVm updateVirtualMachine(long id, String displayName, String group, Bo if (userData != null) { // check and replace newlines userData = userData.replace("\\n", ""); - userData = validateUserData(userData, httpMethod); + userData = userDataManager.validateUserData(userData, httpMethod); // update userData on domain router. updateUserdata = true; } else { @@ -4073,7 +4067,7 @@ private UserVm getUncheckedUserVmResource(DataCenter zone, String hostName, Stri _accountMgr.checkAccess(owner, AccessType.UseEntry, false, template); // check if the user data is correct - userData = validateUserData(userData, httpmethod); + userData = userDataManager.validateUserData(userData, httpmethod); // Find an SSH public key corresponding to the key pair name, if one is // given @@ -4766,55 +4760,6 @@ public void doInTransactionWithoutResult(TransactionStatus status) { } } - protected String validateUserData(String userData, HTTPMethod httpmethod) { - byte[] decodedUserData = null; - if (userData != null) { - - if (userData.contains("%")) { - try { - userData = URLDecoder.decode(userData, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new InvalidParameterValueException("Url decoding of userdata failed."); - } - } - - if (!Base64.isBase64(userData)) { - throw new InvalidParameterValueException("User data is not base64 encoded"); - } - // If GET, use 4K. If POST, support up to 1M. - if (httpmethod.equals(HTTPMethod.GET)) { - if (userData.length() >= MAX_HTTP_GET_LENGTH) { - throw new InvalidParameterValueException("User data is too long for an http GET request"); - } - if (userData.length() > VM_USERDATA_MAX_LENGTH.value()) { - throw new InvalidParameterValueException("User data has exceeded configurable max length : " + VM_USERDATA_MAX_LENGTH.value()); - } - decodedUserData = Base64.decodeBase64(userData.getBytes()); - if (decodedUserData.length > MAX_HTTP_GET_LENGTH) { - throw new InvalidParameterValueException("User data is too long for GET request"); - } - } else if (httpmethod.equals(HTTPMethod.POST)) { - if (userData.length() >= MAX_HTTP_POST_LENGTH) { - throw new InvalidParameterValueException("User data is too long for an http POST request"); - } - if (userData.length() > VM_USERDATA_MAX_LENGTH.value()) { - throw new InvalidParameterValueException("User data has exceeded configurable max length : " + VM_USERDATA_MAX_LENGTH.value()); - } - decodedUserData = Base64.decodeBase64(userData.getBytes()); - if (decodedUserData.length > MAX_HTTP_POST_LENGTH) { - throw new InvalidParameterValueException("User data is too long for POST request"); - } - } - - if (decodedUserData == null || decodedUserData.length < 1) { - throw new InvalidParameterValueException("User data is too short"); - } - // Re-encode so that the '=' paddings are added if necessary since 'isBase64' does not require it, but python does on the VR. - return Base64.encodeBase64String(decodedUserData); - } - return null; - } - @Override @ActionEvent(eventType = EventTypes.EVENT_VM_CREATE, eventDescription = "deploying Vm", async = true) public UserVm startVirtualMachine(DeployVMCmd cmd) throws ResourceUnavailableException, InsufficientCapacityException, ConcurrentOperationException, ResourceAllocationException { @@ -5852,6 +5797,7 @@ public UserVm createVirtualMachine(DeployVMCmd cmd) throws InsufficientCapacityE } String userData = cmd.getUserData(); + userData = userDataManager.validateUserData(userData, cmd.getHttpMethod()); Long userDataId = cmd.getUserdataId(); String userDataDetails = null; if (MapUtils.isNotEmpty(cmd.getUserdataDetails())) { diff --git a/server/src/test/java/com/cloud/server/ManagementServerImplTest.java b/server/src/test/java/com/cloud/server/ManagementServerImplTest.java index cf8df1ad3721..1de5b256dbd7 100644 --- a/server/src/test/java/com/cloud/server/ManagementServerImplTest.java +++ b/server/src/test/java/com/cloud/server/ManagementServerImplTest.java @@ -22,6 +22,35 @@ import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.when; +import java.util.ArrayList; +import java.util.List; + +import org.apache.cloudstack.annotation.dao.AnnotationDao; +import org.apache.cloudstack.api.ApiConstants; +import org.apache.cloudstack.api.BaseCmd; +import org.apache.cloudstack.api.command.user.address.ListPublicIpAddressesCmd; +import org.apache.cloudstack.api.command.user.ssh.RegisterSSHKeyPairCmd; +import org.apache.cloudstack.api.command.user.userdata.DeleteUserDataCmd; +import org.apache.cloudstack.api.command.user.userdata.ListUserDataCmd; +import org.apache.cloudstack.api.command.user.userdata.RegisterUserDataCmd; +import org.apache.cloudstack.context.CallContext; +import org.apache.cloudstack.framework.config.ConfigKey; +import org.apache.cloudstack.userdata.UserDataManager; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.Spy; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; +import org.springframework.test.util.ReflectionTestUtils; + import com.cloud.dc.Vlan.VlanType; import com.cloud.exception.InvalidParameterValueException; import com.cloud.host.DetailVO; @@ -49,37 +78,8 @@ import com.cloud.utils.exception.CloudRuntimeException; import com.cloud.vm.UserVmDetailVO; import com.cloud.vm.UserVmVO; -import com.cloud.vm.dao.UserVmDetailsDao; import com.cloud.vm.dao.UserVmDao; - -import org.apache.cloudstack.annotation.dao.AnnotationDao; -import org.apache.cloudstack.api.ApiConstants; -import org.apache.cloudstack.api.BaseCmd; -import org.apache.cloudstack.api.command.user.address.ListPublicIpAddressesCmd; -import org.apache.cloudstack.api.command.user.ssh.RegisterSSHKeyPairCmd; -import org.apache.cloudstack.api.command.user.userdata.DeleteUserDataCmd; -import org.apache.cloudstack.api.command.user.userdata.ListUserDataCmd; -import org.apache.cloudstack.api.command.user.userdata.RegisterUserDataCmd; -import org.apache.cloudstack.context.CallContext; -import org.apache.cloudstack.framework.config.ConfigKey; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.Spy; -import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; -import org.powermock.reflect.Whitebox; -import org.springframework.test.util.ReflectionTestUtils; - -import java.util.ArrayList; -import java.util.List; +import com.cloud.vm.dao.UserVmDetailsDao; @RunWith(PowerMockRunner.class) @PrepareForTest(CallContext.class) @@ -121,6 +121,9 @@ public class ManagementServerImplTest { @Mock UserVmDao _userVmDao; + @Mock + UserDataManager userDataManager; + @Spy ManagementServerImpl spy = new ManagementServerImpl(); @@ -145,6 +148,7 @@ public void setup() { spy.annotationDao = annotationDao; spy._UserVmDetailsDao = userVmDetailsDao; spy._detailsDao = hostDetailsDao; + spy.userDataManager = userDataManager; } @After @@ -304,13 +308,15 @@ public void testSuccessfulRegisterUserdata() { when(callContextMock.getCallingAccount()).thenReturn(account); when(_accountMgr.finalizeOwner(nullable(Account.class), nullable(String.class), nullable(Long.class), nullable(Long.class))).thenReturn(account); + String testUserData = "testUserdata"; RegisterUserDataCmd cmd = Mockito.mock(RegisterUserDataCmd.class); - when(cmd.getUserData()).thenReturn("testUserdata"); + when(cmd.getUserData()).thenReturn(testUserData); when(cmd.getName()).thenReturn("testName"); when(cmd.getHttpMethod()).thenReturn(BaseCmd.HTTPMethod.GET); when(_userDataDao.findByName(account.getAccountId(), account.getDomainId(), "testName")).thenReturn(null); - when(_userDataDao.findByUserData(account.getAccountId(), account.getDomainId(), "testUserdata")).thenReturn(null); + when(_userDataDao.findByUserData(account.getAccountId(), account.getDomainId(), testUserData)).thenReturn(null); + when(userDataManager.validateUserData(testUserData,BaseCmd.HTTPMethod.GET)).thenReturn(testUserData); UserData userData = spy.registerUserData(cmd); Assert.assertEquals("testName", userData.getName()); diff --git a/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java b/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java index f91b52b867b2..ef001906d981 100644 --- a/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java +++ b/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java @@ -836,10 +836,13 @@ public void testResetVMUserDataSuccessResetWithUserdata() { when(templateDao.findByIdIncludingRemoved(2L)).thenReturn(template); when(template.getUserDataId()).thenReturn(null); - when(cmd.getUserData()).thenReturn("testUserdata"); + String testUserData = "testUserdata"; + when(cmd.getUserData()).thenReturn(testUserData); when(cmd.getUserdataId()).thenReturn(null); when(cmd.getHttpMethod()).thenReturn(HTTPMethod.GET); + when(userDataManager.validateUserData(testUserData, HTTPMethod.GET)).thenReturn(testUserData); + try { doNothing().when(userVmManagerImpl).updateUserData(userVmVO); userVmManagerImpl.resetVMUserData(cmd); @@ -873,12 +876,15 @@ public void testResetVMUserDataSuccessResetWithUserdataId() { when(templateDao.findByIdIncludingRemoved(2L)).thenReturn(template); when(template.getUserDataId()).thenReturn(null); + String testUserData = "testUserdata"; when(cmd.getUserdataId()).thenReturn(1L); UserDataVO apiUserDataVO = Mockito.mock(UserDataVO.class); when(userDataDao.findById(1L)).thenReturn(apiUserDataVO); - when(apiUserDataVO.getUserData()).thenReturn("testUserdata"); + when(apiUserDataVO.getUserData()).thenReturn(testUserData); when(cmd.getHttpMethod()).thenReturn(HTTPMethod.GET); + when(userDataManager.validateUserData(testUserData, HTTPMethod.GET)).thenReturn(testUserData); + try { doNothing().when(userVmManagerImpl).updateUserData(userVmVO); userVmManagerImpl.resetVMUserData(cmd); diff --git a/server/src/test/java/com/cloud/vm/UserVmManagerTest.java b/server/src/test/java/com/cloud/vm/UserVmManagerTest.java index 7cc2c8a6be13..a0ad32153688 100644 --- a/server/src/test/java/com/cloud/vm/UserVmManagerTest.java +++ b/server/src/test/java/com/cloud/vm/UserVmManagerTest.java @@ -18,7 +18,6 @@ package com.cloud.vm; import static org.hamcrest.Matchers.instanceOf; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -38,7 +37,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.io.UnsupportedEncodingException; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.HashMap; @@ -48,7 +46,6 @@ import org.apache.cloudstack.acl.ControlledEntity; import org.apache.cloudstack.acl.SecurityChecker.AccessType; -import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.api.command.admin.vm.AssignVMCmd; import org.apache.cloudstack.api.command.user.vm.RestoreVMCmd; import org.apache.cloudstack.api.command.user.vm.ScaleVMCmd; @@ -841,26 +838,4 @@ public void testPersistDeviceBusInfo() { _userVmMgr.persistDeviceBusInfo(_vmMock, "lsilogic"); verify(_vmDao, times(1)).saveDetails(any(UserVmVO.class)); } - - @Test - public void testValideBase64WithoutPadding() { - // fo should be encoded in base64 either as Zm8 or Zm8= - String encodedUserdata = "Zm8"; - String encodedUserdataWithPadding = "Zm8="; - - // Verify that we accept both but return the padded version - assertTrue("validate return the value with padding", encodedUserdataWithPadding.equals(_userVmMgr.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET))); - assertTrue("validate return the value with padding", encodedUserdataWithPadding.equals(_userVmMgr.validateUserData(encodedUserdataWithPadding, BaseCmd.HTTPMethod.GET))); - } - - @Test - public void testValidateUrlEncodedBase64() throws UnsupportedEncodingException { - // fo should be encoded in base64 either as Zm8 or Zm8= - String encodedUserdata = "Zm+8/w8="; - String urlEncodedUserdata = java.net.URLEncoder.encode(encodedUserdata, "UTF-8"); - - // Verify that we accept both but return the padded version - assertEquals("validate return the value with padding", encodedUserdata, _userVmMgr.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET)); - assertEquals("validate return the value with padding", encodedUserdata, _userVmMgr.validateUserData(urlEncodedUserdata, BaseCmd.HTTPMethod.GET)); - } } diff --git a/test/integration/smoke/test_register_userdata.py b/test/integration/smoke/test_register_userdata.py index 5c954a876ec0..c89d08e63e83 100644 --- a/test/integration/smoke/test_register_userdata.py +++ b/test/integration/smoke/test_register_userdata.py @@ -31,6 +31,8 @@ from marvin.lib.utils import (validateList, cleanup_resources) from nose.plugins.attrib import attr from marvin.codes import PASS,FAIL +import base64 +import email from marvin.lib.common import (get_domain, get_template) @@ -592,24 +594,20 @@ def test_deploy_vm_with_registered_userdata_with_override_policy_append(self): userdata and configured to VM as a multipart MIME userdata. Verify the same by SSH into VM. """ - # #!/bin/bash - # date > /provisioned + shellscript_userdata = str("#!/bin/bash\ndate > /provisioned") self.apiUserdata = UserData.register( self.apiclient, name="ApiUserdata", - userdata="IyEvYmluL2Jhc2gKZGF0ZSA+IC9wcm92aXNpb25lZA==", + userdata=base64.encodebytes(shellscript_userdata.encode()).decode(), account=self.account.name, domainid=self.account.domainid ) - # #cloud-config - # password: atomic - # chpasswd: { expire: False } - # ssh_pwauth: True + cloudconfig_userdata = str("#cloud-config\npassword: atomic\nchpasswd: { expire: False }\nssh_pwauth: True") self.templateUserdata = UserData.register( self.apiclient, name="TemplateUserdata", - userdata="I2Nsb3VkLWNvbmZpZwpwYXNzd29yZDogYXRvbWljCmNocGFzc3dkOiB7IGV4cGlyZTogRmFsc2UgfQpzc2hfcHdhdXRoOiBUcnVl", + userdata=base64.encodebytes(cloudconfig_userdata.encode()).decode(), account=self.account.name, domainid=self.account.domainid ) @@ -707,9 +705,32 @@ def test_deploy_vm_with_registered_userdata_with_override_policy_append(self): res = ssh.execute(cmd) self.debug("Verifying userdata in the VR") self.assertTrue( - "Content-Type: multipart" in str(res[2]), + res is not None and len(res) > 0, + "Resultant userdata is not valid" + ) + msg = email.message_from_string('\n'.join(res)) + self.assertTrue( + msg.is_multipart(), "Failed to match multipart userdata" ) + shellscript_userdata_found = False + cloudconfig_userdata_found = False + for part in msg.get_payload(): + content_type = part.get_content_type() + payload = part.get_payload(decode=True).decode() + if "shellscript" in content_type: + shellscript_userdata_found = shellscript_userdata == payload + elif "cloud-config" in content_type: + cloudconfig_userdata_found = cloudconfig_userdata == payload + + self.assertTrue( + shellscript_userdata_found, + "Failed to find shellscript userdata in append result" + ) + self.assertTrue( + cloudconfig_userdata_found, + "Failed to find cloud-config userdata in append result" + ) @attr(tags=['advanced', 'simulator', 'basic', 'sg', 'testnow'], required_hardware=True) def test_deploy_vm_with_registered_userdata_with_override_policy_deny(self):