diff --git a/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java b/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java index 2fc3acd45d15..4dfcd0a7de1b 100644 --- a/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java +++ b/api/src/main/java/org/apache/cloudstack/userdata/UserDataManager.java @@ -16,9 +16,12 @@ // under the License. package org.apache.cloudstack.userdata; -import com.cloud.utils.component.Manager; +import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.framework.config.Configurable; +import com.cloud.utils.component.Manager; + public interface UserDataManager extends Manager, Configurable { String concatenateUserData(String userdata1, String userdata2, String userdataProvider); + String validateUserData(String userData, BaseCmd.HTTPMethod httpmethod); } diff --git a/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java b/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java index c5caa312b584..5343fb632b54 100644 --- a/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java +++ b/engine/components-api/src/main/java/com/cloud/configuration/ConfigurationManager.java @@ -20,6 +20,7 @@ import java.util.Map; import java.util.Set; +import org.apache.cloudstack.framework.config.ConfigKey; import org.apache.cloudstack.framework.config.impl.ConfigurationSubGroupVO; import com.cloud.dc.ClusterVO; @@ -59,6 +60,10 @@ public interface ConfigurationManager { public static final String MESSAGE_CREATE_VLAN_IP_RANGE_EVENT = "Message.CreateVlanIpRange.Event"; public static final String MESSAGE_DELETE_VLAN_IP_RANGE_EVENT = "Message.DeleteVlanIpRange.Event"; + static final String VM_USERDATA_MAX_LENGTH_STRING = "vm.userdata.max.length"; + static final ConfigKey VM_USERDATA_MAX_LENGTH = new ConfigKey<>("Advanced", Integer.class, VM_USERDATA_MAX_LENGTH_STRING, "32768", + "Max length of vm userdata after base64 decoding. Default is 32768 and maximum is 1048576", true); + /** * @param offering * @return diff --git a/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java b/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java index c61f37a18966..65996f181a9c 100644 --- a/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java +++ b/engine/userdata/cloud-init/src/main/java/org/apache/cloudstack/userdata/CloudInitUserDataProvider.java @@ -19,7 +19,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import java.io.InputStream; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -35,12 +35,14 @@ import javax.mail.internet.MimeMessage; import javax.mail.internet.MimeMultipart; +import org.apache.commons.codec.binary.Base64; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.log4j.Logger; import com.cloud.utils.component.AdapterBase; import com.cloud.utils.exception.CloudRuntimeException; +import com.sun.mail.util.BASE64DecoderStream; public class CloudInitUserDataProvider extends AdapterBase implements UserDataProvider { @@ -69,11 +71,11 @@ public String getName() { return "cloud-init"; } - protected boolean isGZipped(String userdata) { - if (StringUtils.isEmpty(userdata)) { + protected boolean isGZipped(String encodedUserdata) { + if (StringUtils.isEmpty(encodedUserdata)) { return false; } - byte[] data = userdata.getBytes(StandardCharsets.ISO_8859_1); + byte[] data = Base64.decodeBase64(encodedUserdata); if (data.length < 2) { return false; } @@ -82,9 +84,6 @@ protected boolean isGZipped(String userdata) { } protected String extractUserDataHeader(String userdata) { - if (isGZipped(userdata)) { - throw new CloudRuntimeException("Gzipped user data can not be used together with other user data formats"); - } List lines = Arrays.stream(userdata.split("\n")) .filter(x -> (x.startsWith("#") && !x.startsWith("##")) || (x.startsWith("Content-Type:"))) .collect(Collectors.toList()); @@ -131,7 +130,7 @@ protected FormatType getUserDataFormatType(String userdata) { private String getContentType(String userData, FormatType formatType) throws MessagingException { if (formatType == FormatType.MIME) { - MimeMessage msg = new MimeMessage(session, new ByteArrayInputStream(userData.getBytes())); + NoIdMimeMessage msg = new NoIdMimeMessage(session, new ByteArrayInputStream(userData.getBytes())); return msg.getContentType(); } if (!formatContentTypeMap.containsKey(formatType)) { @@ -141,15 +140,35 @@ private String getContentType(String userData, FormatType formatType) throws Mes return formatContentTypeMap.get(formatType); } - protected MimeBodyPart generateBodyPartMIMEMessage(String userData, FormatType formatType) throws MessagingException { + protected String getBodyPartContentAsString(BodyPart bodyPart) throws MessagingException, IOException { + Object content = bodyPart.getContent(); + if (content instanceof BASE64DecoderStream) { + return new String(((BASE64DecoderStream)bodyPart.getContent()).readAllBytes()); + } else if (content instanceof ByteArrayInputStream) { + return new String(((ByteArrayInputStream)bodyPart.getContent()).readAllBytes()); + } else if (content instanceof String) { + return (String)bodyPart.getContent(); + } + throw new CloudRuntimeException(String.format("Failed to get content for multipart data with content type: %s", getBodyPartContentType(bodyPart))); + } + + private String getBodyPartContentType(BodyPart bodyPart) throws MessagingException { + String contentType = StringUtils.defaultString(bodyPart.getDataHandler().getContentType(), bodyPart.getContentType()); + return contentType.contains(";") ? contentType.substring(0, contentType.indexOf(';')) : contentType; + } + + protected MimeBodyPart generateBodyPartMimeMessage(String userData, String contentType) throws MessagingException { MimeBodyPart bodyPart = new MimeBodyPart(); - String contentType = getContentType(userData, formatType); bodyPart.setContent(userData, contentType); bodyPart.addHeader("Content-Transfer-Encoding", "base64"); return bodyPart; } - private Multipart getMessageContent(MimeMessage message) { + protected MimeBodyPart generateBodyPartMimeMessage(String userData, FormatType formatType) throws MessagingException { + return generateBodyPartMimeMessage(userData, getContentType(userData, formatType)); + } + + private Multipart getMessageContent(NoIdMimeMessage message) { Multipart messageContent; try { messageContent = (MimeMultipart) message.getContent(); @@ -159,40 +178,83 @@ private Multipart getMessageContent(MimeMessage message) { return messageContent; } - private void addBodyPartsToMessageContentFromUserDataContent(Multipart messageContent, - MimeMessage msgFromUserdata) throws MessagingException, IOException { - Multipart msgFromUserdataParts = (MimeMultipart) msgFromUserdata.getContent(); - int count = msgFromUserdataParts.getCount(); - int i = 0; - while (i < count) { - BodyPart bodyPart = msgFromUserdataParts.getBodyPart(0); - messageContent.addBodyPart(bodyPart); - i++; + private void addBodyPartToMultipart(Multipart existingMultipart, MimeBodyPart bodyPart) throws MessagingException, IOException { + boolean added = false; + final int existingCount = existingMultipart.getCount(); + for (int j = 0; j < existingCount; ++j) { + MimeBodyPart existingBodyPart = (MimeBodyPart)existingMultipart.getBodyPart(j); + String existingContentType = getBodyPartContentType(existingBodyPart); + String newContentType = getBodyPartContentType(bodyPart); + if (existingContentType.equals(newContentType)) { + String existingContent = getBodyPartContentAsString(existingBodyPart); + String newContent = getBodyPartContentAsString(bodyPart); + // generating a combined content MimeBodyPart to replace + MimeBodyPart combinedBodyPart = generateBodyPartMimeMessage( + simpleAppendSameFormatTypeUserData(existingContent, newContent), existingContentType); + existingMultipart.removeBodyPart(j); + existingMultipart.addBodyPart(combinedBodyPart, j); + added = true; + break; + } + } + if (!added) { + existingMultipart.addBodyPart(bodyPart); + } + } + + private void addBodyPartsToMessageContentFromUserDataContent(Multipart existingMultipart, + NoIdMimeMessage msgFromUserdata) throws MessagingException, IOException { + MimeMultipart newMultipart = (MimeMultipart)msgFromUserdata.getContent(); + final int existingCount = existingMultipart.getCount(); + final int newCount = newMultipart.getCount(); + for (int i = 0; i < newCount; ++i) { + BodyPart bodyPart = newMultipart.getBodyPart(i); + if (existingCount == 0) { + existingMultipart.addBodyPart(bodyPart); + continue; + } + addBodyPartToMultipart(existingMultipart, (MimeBodyPart)bodyPart); } } - private MimeMessage createMultipartMessageAddingUserdata(String userData, FormatType formatType, - MimeMessage message) throws MessagingException, IOException { - MimeMessage newMessage = new MimeMessage(session); + private NoIdMimeMessage createMultipartMessageAddingUserdata(String userData, FormatType formatType, + NoIdMimeMessage message) throws MessagingException, IOException { + NoIdMimeMessage newMessage = new NoIdMimeMessage(session); Multipart messageContent = getMessageContent(message); if (formatType == FormatType.MIME) { - MimeMessage msgFromUserdata = new MimeMessage(session, new ByteArrayInputStream(userData.getBytes())); + NoIdMimeMessage msgFromUserdata = new NoIdMimeMessage(session, new ByteArrayInputStream(userData.getBytes())); addBodyPartsToMessageContentFromUserDataContent(messageContent, msgFromUserdata); } else { - MimeBodyPart part = generateBodyPartMIMEMessage(userData, formatType); - messageContent.addBodyPart(part); + MimeBodyPart part = generateBodyPartMimeMessage(userData, formatType); + addBodyPartToMultipart(messageContent, part); } newMessage.setContent(messageContent); return newMessage; } + private String simpleAppendSameFormatTypeUserData(String userData1, String userData2) { + return String.format("%s\n\n%s", userData1, userData2.substring(userData2.indexOf('\n')+1)); + } + + private void checkGzipAppend(String encodedUserData1, String encodedUserData2) { + if (isGZipped(encodedUserData1) || isGZipped(encodedUserData2)) { + throw new CloudRuntimeException("Gzipped user data can not be used together with other user data formats"); + } + } + @Override - public String appendUserData(String userData1, String userData2) { + public String appendUserData(String encodedUserData1, String encodedUserData2) { try { + checkGzipAppend(encodedUserData1, encodedUserData2); + String userData1 = new String(Base64.decodeBase64(encodedUserData1)); + String userData2 = new String(Base64.decodeBase64(encodedUserData2)); FormatType formatType1 = getUserDataFormatType(userData1); FormatType formatType2 = getUserDataFormatType(userData2); - MimeMessage message = new MimeMessage(session); + if (formatType1.equals(formatType2) && List.of(FormatType.CLOUD_CONFIG, FormatType.BASH_SCRIPT).contains(formatType1)) { + return simpleAppendSameFormatTypeUserData(userData1, userData2); + } + NoIdMimeMessage message = new NoIdMimeMessage(session); message = createMultipartMessageAddingUserdata(userData1, formatType1, message); message = createMultipartMessageAddingUserdata(userData2, formatType2, message); ByteArrayOutputStream output = new ByteArrayOutputStream(); @@ -205,4 +267,20 @@ public String appendUserData(String userData1, String userData2) { throw new CloudRuntimeException(msg, e); } } + + /* This is a wrapper class just to remove Message-ID header from the resultant + multipart data which may contain server details. + */ + private class NoIdMimeMessage extends MimeMessage { + NoIdMimeMessage (Session session) { + super(session); + } + NoIdMimeMessage (Session session, InputStream is) throws MessagingException { + super(session, is); + } + @Override + protected void updateMessageID() throws MessagingException { + removeHeader("Message-ID"); + } + } } diff --git a/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java b/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java index b91438c5a360..4ca9fb7ebd67 100644 --- a/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java +++ b/engine/userdata/cloud-init/src/test/java/org/apache/cloudstack/userdata/CloudInitUserDataProviderTest.java @@ -16,11 +16,20 @@ // under the License. package org.apache.cloudstack.userdata; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Properties; import java.util.zip.GZIPOutputStream; +import javax.mail.BodyPart; +import javax.mail.MessagingException; +import javax.mail.Session; +import javax.mail.internet.MimeMessage; +import javax.mail.internet.MimeMultipart; + +import org.apache.commons.codec.binary.Base64; import org.junit.Assert; import org.junit.Test; @@ -34,6 +43,33 @@ public class CloudInitUserDataProviderTest { "runcmd:\n" + " - echo 'TestVariable {{ ds.meta_data.variable1 }}' >> /tmp/variable\n" + " - echo 'Hostname {{ ds.meta_data.public_hostname }}' > /tmp/hostname"; + private final static String CLOUD_CONFIG_USERDATA1 = "#cloud-config\n" + + "password: atomic\n" + + "chpasswd: { expire: False }\n" + + "ssh_pwauth: True"; + private final static String SHELL_SCRIPT_USERDATA = "#!/bin/bash\n" + + "date > /provisioned"; + private final static String SHELL_SCRIPT_USERDATA1 = "#!/bin/bash\n" + + "mkdir /tmp/test"; + private final static String SINGLE_BODYPART_CLOUDCONFIG_MULTIPART_USERDATA = + "Content-Type: multipart/mixed; boundary=\"//\"\n" + + "MIME-Version: 1.0\n" + + "\n" + + "--//\n" + + "Content-Type: text/cloud-config; charset=\"us-ascii\"\n" + + "MIME-Version: 1.0\n" + + "Content-Transfer-Encoding: 7bit\n" + + "Content-Disposition: attachment; filename=\"cloud-config.txt\"\n" + + "\n" + + "#cloud-config\n" + + "\n" + + "# Upgrade the instance on first boot\n" + + "# (ie run apt-get upgrade)\n" + + "#\n" + + "# Default: false\n" + + "# Aliases: apt_upgrade\n" + + "package_upgrade: true"; + private static final Session session = Session.getDefaultInstance(new Properties()); @Test public void testGetUserDataFormatType() { @@ -54,51 +90,81 @@ public void testGetUserDataFormatTypeInvalidType() { provider.getUserDataFormatType(userdata); } + private MimeMultipart getCheckedMultipartFromMultipartData(String multipartUserData, int count) { + MimeMultipart multipart = null; + Assert.assertTrue(multipartUserData.contains("Content-Type: multipart")); + try { + MimeMessage msgFromUserdata = new MimeMessage(session, + new ByteArrayInputStream(multipartUserData.getBytes())); + multipart = (MimeMultipart)msgFromUserdata.getContent(); + Assert.assertEquals(count, multipart.getCount()); + } catch (MessagingException | IOException e) { + Assert.fail(String.format("Failed with exception, %s", e.getMessage())); + } + return multipart; + } + @Test public void testAppendUserData() { - String templateData = "#cloud-config\n" + - "password: atomic\n" + - "chpasswd: { expire: False }\n" + - "ssh_pwauth: True"; - String vmData = "#!/bin/bash\n" + - "date > /provisioned"; - String multipartUserData = provider.appendUserData(templateData, vmData); - Assert.assertTrue(multipartUserData.contains("Content-Type: multipart")); + String multipartUserData = provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA1.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); + getCheckedMultipartFromMultipartData(multipartUserData, 2); + } + + @Test + public void testAppendSameShellScriptTypeUserData() { + String result = SHELL_SCRIPT_USERDATA + "\n\n" + + SHELL_SCRIPT_USERDATA1.replace("#!/bin/bash\n", ""); + String appendUserData = provider.appendUserData(Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA1.getBytes())); + Assert.assertEquals(result, appendUserData); + } + + @Test + public void testAppendSameCloudConfigTypeUserData() { + String result = CLOUD_CONFIG_USERDATA + "\n\n" + + CLOUD_CONFIG_USERDATA1.replace("#cloud-config\n", ""); + String appendUserData = provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA.getBytes()), + Base64.encodeBase64String(CLOUD_CONFIG_USERDATA1.getBytes())); + Assert.assertEquals(result, appendUserData); } @Test public void testAppendUserDataMIMETemplateData() { - String templateData = "Content-Type: multipart/mixed; boundary=\"//\"\n" + - "MIME-Version: 1.0\n" + - "\n" + - "--//\n" + - "Content-Type: text/cloud-config; charset=\"us-ascii\"\n" + - "MIME-Version: 1.0\n" + - "Content-Transfer-Encoding: 7bit\n" + - "Content-Disposition: attachment; filename=\"cloud-config.txt\"\n" + - "\n" + - "#cloud-config\n" + - "\n" + - "# Upgrade the instance on first boot\n" + - "# (ie run apt-get upgrade)\n" + - "#\n" + - "# Default: false\n" + - "# Aliases: apt_upgrade\n" + - "package_upgrade: true"; - String vmData = "#!/bin/bash\n" + - "date > /provisioned"; - String multipartUserData = provider.appendUserData(templateData, vmData); - Assert.assertTrue(multipartUserData.contains("Content-Type: multipart")); + String multipartUserData = provider.appendUserData( + Base64.encodeBase64String(SINGLE_BODYPART_CLOUDCONFIG_MULTIPART_USERDATA.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); + getCheckedMultipartFromMultipartData(multipartUserData, 2); + } + + @Test + public void testAppendUserDataExistingMultipartWithSameType() { + String templateData = provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA1.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); + String multipartUserData = provider.appendUserData(Base64.encodeBase64String(templateData.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA1.getBytes())); + String resultantShellScript = SHELL_SCRIPT_USERDATA + "\n\n" + + SHELL_SCRIPT_USERDATA1.replace("#!/bin/bash\n", ""); + MimeMultipart mimeMultipart = getCheckedMultipartFromMultipartData(multipartUserData, 2); + try { + for (int i = 0; i < mimeMultipart.getCount(); ++i) { + BodyPart bodyPart = mimeMultipart.getBodyPart(i); + if (bodyPart.getContentType().startsWith("text/x-shellscript")) { + Assert.assertEquals(resultantShellScript, provider.getBodyPartContentAsString(bodyPart)); + } else if (bodyPart.getContentType().startsWith("text/cloud-config")) { + Assert.assertEquals(CLOUD_CONFIG_USERDATA1, provider.getBodyPartContentAsString(bodyPart)); + } + } + } catch (MessagingException | IOException | CloudRuntimeException e) { + Assert.fail(String.format("Failed with exception, %s", e.getMessage())); + } } @Test(expected = CloudRuntimeException.class) public void testAppendUserDataInvalidUserData() { - String templateData = "password: atomic\n" + - "chpasswd: { expire: False }\n" + - "ssh_pwauth: True"; - String vmData = "#!/bin/bash\n" + - "date > /provisioned"; - provider.appendUserData(templateData, vmData); + String templateData = CLOUD_CONFIG_USERDATA1.replace("#cloud-config\n", ""); + provider.appendUserData(Base64.encodeBase64String(templateData.getBytes()), + Base64.encodeBase64String(SHELL_SCRIPT_USERDATA.getBytes())); } @Test @@ -106,7 +172,7 @@ public void testIsGzippedUserDataWithCloudConfigData() { Assert.assertFalse(provider.isGZipped(CLOUD_CONFIG_USERDATA)); } - private String createGzipDataAsString() throws IOException { + private String createBase64EncodedGzipDataAsString() throws IOException { byte[] input = CLOUD_CONFIG_USERDATA.getBytes(StandardCharsets.ISO_8859_1); ByteArrayOutputStream arrayOutputStream = new ByteArrayOutputStream(); @@ -114,13 +180,13 @@ private String createGzipDataAsString() throws IOException { outputStream.write(input,0, input.length); outputStream.close(); - return arrayOutputStream.toString(StandardCharsets.ISO_8859_1); + return Base64.encodeBase64String(arrayOutputStream.toByteArray()); } @Test public void testIsGzippedUserDataWithValidGzipData() { try { - String gzipped = createGzipDataAsString(); + String gzipped = createBase64EncodedGzipDataAsString(); Assert.assertTrue(provider.isGZipped(gzipped)); } catch (IOException e) { Assert.fail(e.getMessage()); @@ -130,7 +196,8 @@ public void testIsGzippedUserDataWithValidGzipData() { @Test(expected = CloudRuntimeException.class) public void testAppendUserDataWithGzippedData() { try { - provider.appendUserData(CLOUD_CONFIG_USERDATA, createGzipDataAsString()); + provider.appendUserData(Base64.encodeBase64String(CLOUD_CONFIG_USERDATA.getBytes()), + createBase64EncodedGzipDataAsString()); Assert.fail("Gzipped data shouldn't be appended with other data"); } catch (IOException e) { Assert.fail("Exception encountered: " + e.getMessage()); diff --git a/engine/userdata/pom.xml b/engine/userdata/pom.xml index 2e00ebd97867..75475b2af183 100644 --- a/engine/userdata/pom.xml +++ b/engine/userdata/pom.xml @@ -43,5 +43,11 @@ activation 1.1.1 + + org.apache.cloudstack + cloud-engine-components-api + 4.19.0.0-SNAPSHOT + compile + diff --git a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java index b2ee9dfd6079..91f24fe70458 100644 --- a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java +++ b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataManagerImpl.java @@ -16,17 +16,29 @@ // under the License. package org.apache.cloudstack.userdata; -import com.cloud.utils.component.ManagerBase; -import com.cloud.utils.exception.CloudRuntimeException; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.framework.config.ConfigKey; import org.apache.commons.codec.binary.Base64; import org.apache.commons.lang3.StringUtils; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import com.cloud.configuration.ConfigurationManager; +import com.cloud.exception.InvalidParameterValueException; +import com.cloud.utils.component.ManagerBase; +import com.cloud.utils.exception.CloudRuntimeException; public class UserDataManagerImpl extends ManagerBase implements UserDataManager { + + + private static final int MAX_USER_DATA_LENGTH_BYTES = 2048; + private static final int MAX_HTTP_GET_LENGTH = 2 * MAX_USER_DATA_LENGTH_BYTES; + private static final int NUM_OF_2K_BLOCKS = 512; + private static final int MAX_HTTP_POST_LENGTH = NUM_OF_2K_BLOCKS * MAX_USER_DATA_LENGTH_BYTES; private List userDataProviders; private static Map userDataProvidersMap = new HashMap<>(); @@ -71,12 +83,56 @@ protected UserDataProvider getUserdataProvider(String name) { @Override public String concatenateUserData(String userdata1, String userdata2, String userdataProvider) { - byte[] userdata1Bytes = Base64.decodeBase64(userdata1.getBytes()); - byte[] userdata2Bytes = Base64.decodeBase64(userdata2.getBytes()); - String userData1Str = new String(userdata1Bytes); - String userData2Str = new String(userdata2Bytes); UserDataProvider provider = getUserdataProvider(userdataProvider); - String appendUserData = provider.appendUserData(userData1Str, userData2Str); + String appendUserData = provider.appendUserData(userdata1, userdata2); return Base64.encodeBase64String(appendUserData.getBytes()); } + + @Override + public String validateUserData(String userData, BaseCmd.HTTPMethod httpmethod) { + byte[] decodedUserData = null; + if (userData != null) { + + if (userData.contains("%")) { + try { + userData = URLDecoder.decode(userData, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new InvalidParameterValueException("Url decoding of userdata failed."); + } + } + + if (!Base64.isBase64(userData)) { + throw new InvalidParameterValueException("User data is not base64 encoded"); + } + // If GET, use 4K. If POST, support up to 1M. + if (httpmethod.equals(BaseCmd.HTTPMethod.GET)) { + decodedUserData = validateAndDecodeByHTTPMethod(userData, MAX_HTTP_GET_LENGTH, BaseCmd.HTTPMethod.GET); + } else if (httpmethod.equals(BaseCmd.HTTPMethod.POST)) { + decodedUserData = validateAndDecodeByHTTPMethod(userData, MAX_HTTP_POST_LENGTH, BaseCmd.HTTPMethod.POST); + } + + if (decodedUserData == null || decodedUserData.length < 1) { + throw new InvalidParameterValueException("User data is too short"); + } + // Re-encode so that the '=' paddings are added if necessary since 'isBase64' does not require it, but python does on the VR. + return Base64.encodeBase64String(decodedUserData); + } + return null; + } + + private byte[] validateAndDecodeByHTTPMethod(String userData, int maxHTTPLength, BaseCmd.HTTPMethod httpMethod) { + byte[] decodedUserData = null; + + if (userData.length() >= maxHTTPLength) { + throw new InvalidParameterValueException(String.format("User data is too long for an http %s request", httpMethod.toString())); + } + if (userData.length() > ConfigurationManager.VM_USERDATA_MAX_LENGTH.value()) { + throw new InvalidParameterValueException("User data has exceeded configurable max length : " + ConfigurationManager.VM_USERDATA_MAX_LENGTH.value()); + } + decodedUserData = Base64.decodeBase64(userData.getBytes()); + if (decodedUserData.length > maxHTTPLength) { + throw new InvalidParameterValueException(String.format("User data is too long for http %s request", httpMethod.toString())); + } + return decodedUserData; + } } diff --git a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java index 9ac577b54ef7..4cdcd516695d 100644 --- a/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java +++ b/engine/userdata/src/main/java/org/apache/cloudstack/userdata/UserDataProvider.java @@ -21,8 +21,8 @@ public interface UserDataProvider { /** * Append user data into a single user data. - * NOTE: userData1 and userData2 are decoded user data strings + * NOTE: userData1 and userData2 are Base64 encoded user data strings * @return a non-encrypted string containing both user data inputs */ - String appendUserData(String userData1, String userData2); + String appendUserData(String encodedUserData1, String encodedUserData2); } diff --git a/engine/userdata/src/test/java/org/apache/cloudstack/userdata/UserDataManagerImplTest.java b/engine/userdata/src/test/java/org/apache/cloudstack/userdata/UserDataManagerImplTest.java new file mode 100644 index 000000000000..67e7b38e37d0 --- /dev/null +++ b/engine/userdata/src/test/java/org/apache/cloudstack/userdata/UserDataManagerImplTest.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.cloudstack.userdata; + +import static org.junit.Assert.assertEquals; + +import java.nio.charset.StandardCharsets; + +import org.apache.cloudstack.api.BaseCmd; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.InjectMocks; +import org.mockito.Spy; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class UserDataManagerImplTest { + + @Spy + @InjectMocks + private UserDataManagerImpl userDataManager; + + @Test + public void testValidateBase64WithoutPadding() { + // fo should be encoded in base64 either as Zm8 or Zm8= + String encodedUserdata = "Zm8"; + String encodedUserdataWithPadding = "Zm8="; + + // Verify that we accept both but return the padded version + assertEquals("validate return the value with padding", encodedUserdataWithPadding, userDataManager.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET)); + assertEquals("validate return the value with padding", encodedUserdataWithPadding, userDataManager.validateUserData(encodedUserdataWithPadding, BaseCmd.HTTPMethod.GET)); + } + + @Test + public void testValidateUrlEncodedBase64() { + // fo should be encoded in base64 either as Zm8 or Zm8= + String encodedUserdata = "Zm+8/w8="; + String urlEncodedUserdata = java.net.URLEncoder.encode(encodedUserdata, StandardCharsets.UTF_8); + + // Verify that we accept both but return the padded version + assertEquals("validate return the value with padding", encodedUserdata, userDataManager.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET)); + assertEquals("validate return the value with padding", encodedUserdata, userDataManager.validateUserData(urlEncodedUserdata, BaseCmd.HTTPMethod.GET)); + } + +} diff --git a/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java b/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java index c89b0e2e3b12..890fb1195e2e 100644 --- a/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java +++ b/server/src/main/java/com/cloud/configuration/ConfigurationManagerImpl.java @@ -460,7 +460,6 @@ public class ConfigurationManagerImpl extends ManagerBase implements Configurati protected Set configValuesForValidation; private Set weightBasedParametersForValidation; private Set overprovisioningFactorsForValidation; - public static final String VM_USERDATA_MAX_LENGTH_STRING = "vm.userdata.max.length"; public static final ConfigKey SystemVMUseLocalStorage = new ConfigKey(Boolean.class, "system.vm.use.local.storage", "Advanced", "false", "Indicates whether to use local storage pools or shared storage pools for system VMs.", false, ConfigKey.Scope.Zone, null); @@ -491,8 +490,6 @@ public class ConfigurationManagerImpl extends ManagerBase implements Configurati public static ConfigKey VM_SERVICE_OFFERING_MAX_RAM_SIZE = new ConfigKey("Advanced", Integer.class, "vm.serviceoffering.ram.size.max", "0", "Maximum RAM size in " + "MB for vm service offering. If 0 - no limitation", true); - public static final ConfigKey VM_USERDATA_MAX_LENGTH = new ConfigKey("Advanced", Integer.class, VM_USERDATA_MAX_LENGTH_STRING, "32768", - "Max length of vm userdata after base64 decoding. Default is 32768 and maximum is 1048576", true); public static final ConfigKey MIGRATE_VM_ACROSS_CLUSTERS = new ConfigKey(Boolean.class, "migrate.vm.across.clusters", "Advanced", "false", "Indicates whether the VM can be migrated to different cluster if no host is found in same cluster",true, ConfigKey.Scope.Zone, null); diff --git a/server/src/main/java/com/cloud/server/ManagementServerImpl.java b/server/src/main/java/com/cloud/server/ManagementServerImpl.java index 913063c8e639..c16dc4eb2f47 100644 --- a/server/src/main/java/com/cloud/server/ManagementServerImpl.java +++ b/server/src/main/java/com/cloud/server/ManagementServerImpl.java @@ -16,12 +16,7 @@ // under the License. package com.cloud.server; -import static com.cloud.configuration.ConfigurationManagerImpl.VM_USERDATA_MAX_LENGTH; -import static com.cloud.vm.UserVmManager.MAX_USER_DATA_LENGTH_BYTES; - -import java.io.UnsupportedEncodingException; import java.lang.reflect.Field; -import java.net.URLDecoder; import java.util.ArrayList; import java.util.Arrays; import java.util.Calendar; @@ -56,7 +51,6 @@ import org.apache.cloudstack.annotation.dao.AnnotationDao; import org.apache.cloudstack.api.ApiCommandResourceType; import org.apache.cloudstack.api.ApiConstants; -import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.api.command.admin.account.CreateAccountCmd; import org.apache.cloudstack.api.command.admin.account.DeleteAccountCmd; import org.apache.cloudstack.api.command.admin.account.DisableAccountCmd; @@ -611,6 +605,7 @@ import org.apache.cloudstack.storage.datastore.db.TemplateDataStoreVO; import org.apache.cloudstack.storage.datastore.db.VolumeDataStoreDao; import org.apache.cloudstack.storage.datastore.db.VolumeDataStoreVO; +import org.apache.cloudstack.userdata.UserDataManager; import org.apache.cloudstack.utils.CloudStackVersion; import org.apache.cloudstack.utils.identity.ManagementServerNode; import org.apache.commons.codec.binary.Base64; @@ -620,13 +615,13 @@ import com.cloud.agent.AgentManager; import com.cloud.agent.api.Answer; -import com.cloud.agent.api.Command; import com.cloud.agent.api.CheckGuestOsMappingAnswer; import com.cloud.agent.api.CheckGuestOsMappingCommand; -import com.cloud.agent.api.GetVncPortAnswer; -import com.cloud.agent.api.GetVncPortCommand; +import com.cloud.agent.api.Command; import com.cloud.agent.api.GetHypervisorGuestOsNamesAnswer; import com.cloud.agent.api.GetHypervisorGuestOsNamesCommand; +import com.cloud.agent.api.GetVncPortAnswer; +import com.cloud.agent.api.GetVncPortCommand; import com.cloud.agent.api.PatchSystemVmAnswer; import com.cloud.agent.api.PatchSystemVmCommand; import com.cloud.agent.api.proxy.AllowConsoleAccessCommand; @@ -696,7 +691,6 @@ import com.cloud.host.dao.HostDao; import com.cloud.host.dao.HostDetailsDao; import com.cloud.host.dao.HostTagsDao; -import com.cloud.hypervisor.Hypervisor; import com.cloud.hypervisor.Hypervisor.HypervisorType; import com.cloud.hypervisor.HypervisorCapabilities; import com.cloud.hypervisor.HypervisorCapabilitiesVO; @@ -779,6 +773,7 @@ import com.cloud.utils.crypt.DBEncryptionUtil; import com.cloud.utils.db.DB; import com.cloud.utils.db.Filter; +import com.cloud.utils.db.GenericSearchBuilder; import com.cloud.utils.db.GlobalLock; import com.cloud.utils.db.JoinBuilder; import com.cloud.utils.db.JoinBuilder.JoinType; @@ -788,7 +783,6 @@ import com.cloud.utils.db.TransactionCallbackNoReturn; import com.cloud.utils.db.TransactionStatus; import com.cloud.utils.db.UUIDManager; -import com.cloud.utils.db.GenericSearchBuilder; import com.cloud.utils.exception.CloudRuntimeException; import com.cloud.utils.fsm.StateMachine2; import com.cloud.utils.net.MacAddress; @@ -827,10 +821,6 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe static final ConfigKey humanReadableSizes = new ConfigKey("Advanced", Boolean.class, "display.human.readable.sizes", "true", "Enables outputting human readable byte sizes to logs and usage records.", false, ConfigKey.Scope.Global); public static final ConfigKey customCsIdentifier = new ConfigKey("Advanced", String.class, "custom.cs.identifier", UUID.randomUUID().toString().split("-")[0].substring(4), "Custom identifier for the cloudstack installation", true, ConfigKey.Scope.Global); private static final VirtualMachine.Type []systemVmTypes = { VirtualMachine.Type.SecondaryStorageVm, VirtualMachine.Type.ConsoleProxy}; - - private static final int MAX_HTTP_GET_LENGTH = 2 * MAX_USER_DATA_LENGTH_BYTES; - private static final int NUM_OF_2K_BLOCKS = 512; - private static final int MAX_HTTP_POST_LENGTH = NUM_OF_2K_BLOCKS * MAX_USER_DATA_LENGTH_BYTES; private static final List LIVE_MIGRATION_SUPPORTING_HYPERVISORS = List.of(HypervisorType.Hyperv, HypervisorType.KVM, HypervisorType.LXC, HypervisorType.Ovm, HypervisorType.Ovm3, HypervisorType.Simulator, HypervisorType.VMware, HypervisorType.XenServer); @@ -982,6 +972,8 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe protected VMTemplateDao templateDao; @Inject protected AnnotationDao annotationDao; + @Inject + UserDataManager userDataManager; private LockControllerListener _lockControllerListener; private final ScheduledExecutorService _eventExecutor = Executors.newScheduledThreadPool(1, new NamedThreadFactory("EventChecker")); @@ -999,7 +991,7 @@ public class ManagementServerImpl extends ManagerBase implements ManagementServe protected List _planners; - private final List supportedHypervisors = new ArrayList(); + private final List supportedHypervisors = new ArrayList(); public List getPlanners() { return _planners; @@ -4711,58 +4703,11 @@ public UserData registerUserData(final RegisterUserDataCmd cmd) { String userdata = cmd.getUserData(); final String params = cmd.getParams(); - userdata = validateUserData(userdata, cmd.getHttpMethod()); + userdata = userDataManager.validateUserData(userdata, cmd.getHttpMethod()); return createAndSaveUserData(name, userdata, params, owner); } - private String validateUserData(String userData, BaseCmd.HTTPMethod httpmethod) { - byte[] decodedUserData = null; - if (userData != null) { - - if (userData.contains("%")) { - try { - userData = URLDecoder.decode(userData, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new InvalidParameterValueException("Url decoding of userdata failed."); - } - } - - if (!Base64.isBase64(userData)) { - throw new InvalidParameterValueException("User data is not base64 encoded"); - } - // If GET, use 4K. If POST, support up to 1M. - if (httpmethod.equals(BaseCmd.HTTPMethod.GET)) { - decodedUserData = validateAndDecodeByHTTPmethod(userData, MAX_HTTP_GET_LENGTH, BaseCmd.HTTPMethod.GET); - } else if (httpmethod.equals(BaseCmd.HTTPMethod.POST)) { - decodedUserData = validateAndDecodeByHTTPmethod(userData, MAX_HTTP_POST_LENGTH, BaseCmd.HTTPMethod.POST); - } - - if (decodedUserData == null || decodedUserData.length < 1) { - throw new InvalidParameterValueException("User data is too short"); - } - // Re-encode so that the '=' paddings are added if necessary since 'isBase64' does not require it, but python does on the VR. - return Base64.encodeBase64String(decodedUserData); - } - return null; - } - - private byte[] validateAndDecodeByHTTPmethod(String userData, int maxHTTPlength, BaseCmd.HTTPMethod httpMethod) { - byte[] decodedUserData = null; - - if (userData.length() >= maxHTTPlength) { - throw new InvalidParameterValueException(String.format("User data is too long for an http %s request", httpMethod.toString())); - } - if (userData.length() > VM_USERDATA_MAX_LENGTH.value()) { - throw new InvalidParameterValueException("User data has exceeded configurable max length : " + VM_USERDATA_MAX_LENGTH.value()); - } - decodedUserData = Base64.decodeBase64(userData.getBytes()); - if (decodedUserData.length > maxHTTPlength) { - throw new InvalidParameterValueException(String.format("User data is too long for http %s request", httpMethod.toString())); - } - return decodedUserData; - } - /** * @param cmd * @param owner diff --git a/server/src/main/java/com/cloud/vm/UserVmManager.java b/server/src/main/java/com/cloud/vm/UserVmManager.java index 39f1e5d2d283..6dd9c27e5809 100644 --- a/server/src/main/java/com/cloud/vm/UserVmManager.java +++ b/server/src/main/java/com/cloud/vm/UserVmManager.java @@ -58,8 +58,6 @@ public interface UserVmManager extends UserVmService { "Destroys the VM's root volume when the VM is destroyed.", true, ConfigKey.Scope.Domain); - static final int MAX_USER_DATA_LENGTH_BYTES = 2048; - public static final String CKS_NODE = "cksnode"; /** diff --git a/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java b/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java index d223fddd8752..159c230b8764 100644 --- a/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java +++ b/server/src/main/java/com/cloud/vm/UserVmManagerImpl.java @@ -16,7 +16,6 @@ // under the License. package com.cloud.vm; -import static com.cloud.configuration.ConfigurationManagerImpl.VM_USERDATA_MAX_LENGTH; import static com.cloud.utils.NumbersUtil.toHumanReadableSize; import java.io.IOException; @@ -127,7 +126,6 @@ import org.apache.cloudstack.utils.bytescale.ByteScaleUtils; import org.apache.cloudstack.utils.security.ParserUtils; import org.apache.cloudstack.vm.schedule.VMScheduleManager; -import org.apache.commons.codec.binary.Base64; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.MapUtils; import org.apache.commons.lang.math.NumberUtils; @@ -603,10 +601,6 @@ public class UserVmManagerImpl extends ManagerBase implements UserVmManager, Vir protected static long ROOT_DEVICE_ID = 0; - private static final int MAX_HTTP_GET_LENGTH = 2 * MAX_USER_DATA_LENGTH_BYTES; - private static final int NUM_OF_2K_BLOCKS = 512; - private static final int MAX_HTTP_POST_LENGTH = NUM_OF_2K_BLOCKS * MAX_USER_DATA_LENGTH_BYTES; - @Inject private OrchestrationService _orchSrvc; @@ -947,7 +941,7 @@ public UserVm resetVMUserData(ResetVMUserDataCmd cmd) throws ResourceUnavailable userDataDetails = cmd.getUserdataDetails().toString(); } userData = finalizeUserData(userData, userDataId, template); - userData = validateUserData(userData, cmd.getHttpMethod()); + userData = userDataManager.validateUserData(userData, cmd.getHttpMethod()); userVm.setUserDataId(userDataId); userVm.setUserData(userData); @@ -2950,7 +2944,7 @@ public UserVm updateVirtualMachine(long id, String displayName, String group, Bo if (userData != null) { // check and replace newlines userData = userData.replace("\\n", ""); - userData = validateUserData(userData, httpMethod); + userData = userDataManager.validateUserData(userData, httpMethod); // update userData on domain router. updateUserdata = true; } else { @@ -4073,7 +4067,7 @@ private UserVm getUncheckedUserVmResource(DataCenter zone, String hostName, Stri _accountMgr.checkAccess(owner, AccessType.UseEntry, false, template); // check if the user data is correct - userData = validateUserData(userData, httpmethod); + userData = userDataManager.validateUserData(userData, httpmethod); // Find an SSH public key corresponding to the key pair name, if one is // given @@ -4766,55 +4760,6 @@ public void doInTransactionWithoutResult(TransactionStatus status) { } } - protected String validateUserData(String userData, HTTPMethod httpmethod) { - byte[] decodedUserData = null; - if (userData != null) { - - if (userData.contains("%")) { - try { - userData = URLDecoder.decode(userData, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new InvalidParameterValueException("Url decoding of userdata failed."); - } - } - - if (!Base64.isBase64(userData)) { - throw new InvalidParameterValueException("User data is not base64 encoded"); - } - // If GET, use 4K. If POST, support up to 1M. - if (httpmethod.equals(HTTPMethod.GET)) { - if (userData.length() >= MAX_HTTP_GET_LENGTH) { - throw new InvalidParameterValueException("User data is too long for an http GET request"); - } - if (userData.length() > VM_USERDATA_MAX_LENGTH.value()) { - throw new InvalidParameterValueException("User data has exceeded configurable max length : " + VM_USERDATA_MAX_LENGTH.value()); - } - decodedUserData = Base64.decodeBase64(userData.getBytes()); - if (decodedUserData.length > MAX_HTTP_GET_LENGTH) { - throw new InvalidParameterValueException("User data is too long for GET request"); - } - } else if (httpmethod.equals(HTTPMethod.POST)) { - if (userData.length() >= MAX_HTTP_POST_LENGTH) { - throw new InvalidParameterValueException("User data is too long for an http POST request"); - } - if (userData.length() > VM_USERDATA_MAX_LENGTH.value()) { - throw new InvalidParameterValueException("User data has exceeded configurable max length : " + VM_USERDATA_MAX_LENGTH.value()); - } - decodedUserData = Base64.decodeBase64(userData.getBytes()); - if (decodedUserData.length > MAX_HTTP_POST_LENGTH) { - throw new InvalidParameterValueException("User data is too long for POST request"); - } - } - - if (decodedUserData == null || decodedUserData.length < 1) { - throw new InvalidParameterValueException("User data is too short"); - } - // Re-encode so that the '=' paddings are added if necessary since 'isBase64' does not require it, but python does on the VR. - return Base64.encodeBase64String(decodedUserData); - } - return null; - } - @Override @ActionEvent(eventType = EventTypes.EVENT_VM_CREATE, eventDescription = "deploying Vm", async = true) public UserVm startVirtualMachine(DeployVMCmd cmd) throws ResourceUnavailableException, InsufficientCapacityException, ConcurrentOperationException, ResourceAllocationException { @@ -5852,6 +5797,7 @@ public UserVm createVirtualMachine(DeployVMCmd cmd) throws InsufficientCapacityE } String userData = cmd.getUserData(); + userData = userDataManager.validateUserData(userData, cmd.getHttpMethod()); Long userDataId = cmd.getUserdataId(); String userDataDetails = null; if (MapUtils.isNotEmpty(cmd.getUserdataDetails())) { diff --git a/server/src/test/java/com/cloud/server/ManagementServerImplTest.java b/server/src/test/java/com/cloud/server/ManagementServerImplTest.java index cf8df1ad3721..1de5b256dbd7 100644 --- a/server/src/test/java/com/cloud/server/ManagementServerImplTest.java +++ b/server/src/test/java/com/cloud/server/ManagementServerImplTest.java @@ -22,6 +22,35 @@ import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.when; +import java.util.ArrayList; +import java.util.List; + +import org.apache.cloudstack.annotation.dao.AnnotationDao; +import org.apache.cloudstack.api.ApiConstants; +import org.apache.cloudstack.api.BaseCmd; +import org.apache.cloudstack.api.command.user.address.ListPublicIpAddressesCmd; +import org.apache.cloudstack.api.command.user.ssh.RegisterSSHKeyPairCmd; +import org.apache.cloudstack.api.command.user.userdata.DeleteUserDataCmd; +import org.apache.cloudstack.api.command.user.userdata.ListUserDataCmd; +import org.apache.cloudstack.api.command.user.userdata.RegisterUserDataCmd; +import org.apache.cloudstack.context.CallContext; +import org.apache.cloudstack.framework.config.ConfigKey; +import org.apache.cloudstack.userdata.UserDataManager; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.Spy; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; +import org.springframework.test.util.ReflectionTestUtils; + import com.cloud.dc.Vlan.VlanType; import com.cloud.exception.InvalidParameterValueException; import com.cloud.host.DetailVO; @@ -49,37 +78,8 @@ import com.cloud.utils.exception.CloudRuntimeException; import com.cloud.vm.UserVmDetailVO; import com.cloud.vm.UserVmVO; -import com.cloud.vm.dao.UserVmDetailsDao; import com.cloud.vm.dao.UserVmDao; - -import org.apache.cloudstack.annotation.dao.AnnotationDao; -import org.apache.cloudstack.api.ApiConstants; -import org.apache.cloudstack.api.BaseCmd; -import org.apache.cloudstack.api.command.user.address.ListPublicIpAddressesCmd; -import org.apache.cloudstack.api.command.user.ssh.RegisterSSHKeyPairCmd; -import org.apache.cloudstack.api.command.user.userdata.DeleteUserDataCmd; -import org.apache.cloudstack.api.command.user.userdata.ListUserDataCmd; -import org.apache.cloudstack.api.command.user.userdata.RegisterUserDataCmd; -import org.apache.cloudstack.context.CallContext; -import org.apache.cloudstack.framework.config.ConfigKey; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.Spy; -import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; -import org.powermock.reflect.Whitebox; -import org.springframework.test.util.ReflectionTestUtils; - -import java.util.ArrayList; -import java.util.List; +import com.cloud.vm.dao.UserVmDetailsDao; @RunWith(PowerMockRunner.class) @PrepareForTest(CallContext.class) @@ -121,6 +121,9 @@ public class ManagementServerImplTest { @Mock UserVmDao _userVmDao; + @Mock + UserDataManager userDataManager; + @Spy ManagementServerImpl spy = new ManagementServerImpl(); @@ -145,6 +148,7 @@ public void setup() { spy.annotationDao = annotationDao; spy._UserVmDetailsDao = userVmDetailsDao; spy._detailsDao = hostDetailsDao; + spy.userDataManager = userDataManager; } @After @@ -304,13 +308,15 @@ public void testSuccessfulRegisterUserdata() { when(callContextMock.getCallingAccount()).thenReturn(account); when(_accountMgr.finalizeOwner(nullable(Account.class), nullable(String.class), nullable(Long.class), nullable(Long.class))).thenReturn(account); + String testUserData = "testUserdata"; RegisterUserDataCmd cmd = Mockito.mock(RegisterUserDataCmd.class); - when(cmd.getUserData()).thenReturn("testUserdata"); + when(cmd.getUserData()).thenReturn(testUserData); when(cmd.getName()).thenReturn("testName"); when(cmd.getHttpMethod()).thenReturn(BaseCmd.HTTPMethod.GET); when(_userDataDao.findByName(account.getAccountId(), account.getDomainId(), "testName")).thenReturn(null); - when(_userDataDao.findByUserData(account.getAccountId(), account.getDomainId(), "testUserdata")).thenReturn(null); + when(_userDataDao.findByUserData(account.getAccountId(), account.getDomainId(), testUserData)).thenReturn(null); + when(userDataManager.validateUserData(testUserData,BaseCmd.HTTPMethod.GET)).thenReturn(testUserData); UserData userData = spy.registerUserData(cmd); Assert.assertEquals("testName", userData.getName()); diff --git a/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java b/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java index f91b52b867b2..ef001906d981 100644 --- a/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java +++ b/server/src/test/java/com/cloud/vm/UserVmManagerImplTest.java @@ -836,10 +836,13 @@ public void testResetVMUserDataSuccessResetWithUserdata() { when(templateDao.findByIdIncludingRemoved(2L)).thenReturn(template); when(template.getUserDataId()).thenReturn(null); - when(cmd.getUserData()).thenReturn("testUserdata"); + String testUserData = "testUserdata"; + when(cmd.getUserData()).thenReturn(testUserData); when(cmd.getUserdataId()).thenReturn(null); when(cmd.getHttpMethod()).thenReturn(HTTPMethod.GET); + when(userDataManager.validateUserData(testUserData, HTTPMethod.GET)).thenReturn(testUserData); + try { doNothing().when(userVmManagerImpl).updateUserData(userVmVO); userVmManagerImpl.resetVMUserData(cmd); @@ -873,12 +876,15 @@ public void testResetVMUserDataSuccessResetWithUserdataId() { when(templateDao.findByIdIncludingRemoved(2L)).thenReturn(template); when(template.getUserDataId()).thenReturn(null); + String testUserData = "testUserdata"; when(cmd.getUserdataId()).thenReturn(1L); UserDataVO apiUserDataVO = Mockito.mock(UserDataVO.class); when(userDataDao.findById(1L)).thenReturn(apiUserDataVO); - when(apiUserDataVO.getUserData()).thenReturn("testUserdata"); + when(apiUserDataVO.getUserData()).thenReturn(testUserData); when(cmd.getHttpMethod()).thenReturn(HTTPMethod.GET); + when(userDataManager.validateUserData(testUserData, HTTPMethod.GET)).thenReturn(testUserData); + try { doNothing().when(userVmManagerImpl).updateUserData(userVmVO); userVmManagerImpl.resetVMUserData(cmd); diff --git a/server/src/test/java/com/cloud/vm/UserVmManagerTest.java b/server/src/test/java/com/cloud/vm/UserVmManagerTest.java index 7cc2c8a6be13..a0ad32153688 100644 --- a/server/src/test/java/com/cloud/vm/UserVmManagerTest.java +++ b/server/src/test/java/com/cloud/vm/UserVmManagerTest.java @@ -18,7 +18,6 @@ package com.cloud.vm; import static org.hamcrest.Matchers.instanceOf; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -38,7 +37,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.io.UnsupportedEncodingException; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.HashMap; @@ -48,7 +46,6 @@ import org.apache.cloudstack.acl.ControlledEntity; import org.apache.cloudstack.acl.SecurityChecker.AccessType; -import org.apache.cloudstack.api.BaseCmd; import org.apache.cloudstack.api.command.admin.vm.AssignVMCmd; import org.apache.cloudstack.api.command.user.vm.RestoreVMCmd; import org.apache.cloudstack.api.command.user.vm.ScaleVMCmd; @@ -841,26 +838,4 @@ public void testPersistDeviceBusInfo() { _userVmMgr.persistDeviceBusInfo(_vmMock, "lsilogic"); verify(_vmDao, times(1)).saveDetails(any(UserVmVO.class)); } - - @Test - public void testValideBase64WithoutPadding() { - // fo should be encoded in base64 either as Zm8 or Zm8= - String encodedUserdata = "Zm8"; - String encodedUserdataWithPadding = "Zm8="; - - // Verify that we accept both but return the padded version - assertTrue("validate return the value with padding", encodedUserdataWithPadding.equals(_userVmMgr.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET))); - assertTrue("validate return the value with padding", encodedUserdataWithPadding.equals(_userVmMgr.validateUserData(encodedUserdataWithPadding, BaseCmd.HTTPMethod.GET))); - } - - @Test - public void testValidateUrlEncodedBase64() throws UnsupportedEncodingException { - // fo should be encoded in base64 either as Zm8 or Zm8= - String encodedUserdata = "Zm+8/w8="; - String urlEncodedUserdata = java.net.URLEncoder.encode(encodedUserdata, "UTF-8"); - - // Verify that we accept both but return the padded version - assertEquals("validate return the value with padding", encodedUserdata, _userVmMgr.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET)); - assertEquals("validate return the value with padding", encodedUserdata, _userVmMgr.validateUserData(urlEncodedUserdata, BaseCmd.HTTPMethod.GET)); - } } diff --git a/test/integration/smoke/test_register_userdata.py b/test/integration/smoke/test_register_userdata.py index 5c954a876ec0..c89d08e63e83 100644 --- a/test/integration/smoke/test_register_userdata.py +++ b/test/integration/smoke/test_register_userdata.py @@ -31,6 +31,8 @@ from marvin.lib.utils import (validateList, cleanup_resources) from nose.plugins.attrib import attr from marvin.codes import PASS,FAIL +import base64 +import email from marvin.lib.common import (get_domain, get_template) @@ -592,24 +594,20 @@ def test_deploy_vm_with_registered_userdata_with_override_policy_append(self): userdata and configured to VM as a multipart MIME userdata. Verify the same by SSH into VM. """ - # #!/bin/bash - # date > /provisioned + shellscript_userdata = str("#!/bin/bash\ndate > /provisioned") self.apiUserdata = UserData.register( self.apiclient, name="ApiUserdata", - userdata="IyEvYmluL2Jhc2gKZGF0ZSA+IC9wcm92aXNpb25lZA==", + userdata=base64.encodebytes(shellscript_userdata.encode()).decode(), account=self.account.name, domainid=self.account.domainid ) - # #cloud-config - # password: atomic - # chpasswd: { expire: False } - # ssh_pwauth: True + cloudconfig_userdata = str("#cloud-config\npassword: atomic\nchpasswd: { expire: False }\nssh_pwauth: True") self.templateUserdata = UserData.register( self.apiclient, name="TemplateUserdata", - userdata="I2Nsb3VkLWNvbmZpZwpwYXNzd29yZDogYXRvbWljCmNocGFzc3dkOiB7IGV4cGlyZTogRmFsc2UgfQpzc2hfcHdhdXRoOiBUcnVl", + userdata=base64.encodebytes(cloudconfig_userdata.encode()).decode(), account=self.account.name, domainid=self.account.domainid ) @@ -707,9 +705,32 @@ def test_deploy_vm_with_registered_userdata_with_override_policy_append(self): res = ssh.execute(cmd) self.debug("Verifying userdata in the VR") self.assertTrue( - "Content-Type: multipart" in str(res[2]), + res is not None and len(res) > 0, + "Resultant userdata is not valid" + ) + msg = email.message_from_string('\n'.join(res)) + self.assertTrue( + msg.is_multipart(), "Failed to match multipart userdata" ) + shellscript_userdata_found = False + cloudconfig_userdata_found = False + for part in msg.get_payload(): + content_type = part.get_content_type() + payload = part.get_payload(decode=True).decode() + if "shellscript" in content_type: + shellscript_userdata_found = shellscript_userdata == payload + elif "cloud-config" in content_type: + cloudconfig_userdata_found = cloudconfig_userdata == payload + + self.assertTrue( + shellscript_userdata_found, + "Failed to find shellscript userdata in append result" + ) + self.assertTrue( + cloudconfig_userdata_found, + "Failed to find cloud-config userdata in append result" + ) @attr(tags=['advanced', 'simulator', 'basic', 'sg', 'testnow'], required_hardware=True) def test_deploy_vm_with_registered_userdata_with_override_policy_deny(self):