Skip to content

Commit

Permalink
Refactor the EncryptionDecryptionUtil
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Liang <jiallian@amazon.com>
  • Loading branch information
RyanL1997 committed Aug 22, 2023
1 parent b315559 commit 1ba378e
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,22 @@

public class EncryptionDecryptionUtil {

public static String encrypt(final String secret, final String data) {
final Cipher cipher = createCipherFromSecret(secret, CipherMode.ENCRYPT);
final byte[] cipherText = createCipherText(cipher, data.getBytes(StandardCharsets.UTF_8));
return Base64.getEncoder().encodeToString(cipherText);
private final Cipher encryptCipher;
private final Cipher decryptCipher;

public EncryptionDecryptionUtil(final String secret) {
this.encryptCipher = createCipherFromSecret(secret, CipherMode.ENCRYPT);
this.decryptCipher = createCipherFromSecret(secret, CipherMode.DECRYPT);
}

public String encrypt(final String data) {
byte[] encryptedBytes = processWithCipher(data.getBytes(StandardCharsets.UTF_8), encryptCipher);
return Base64.getEncoder().encodeToString(encryptedBytes);
}

public static String decrypt(final String secret, final String encryptedString) {
final Cipher cipher = createCipherFromSecret(secret, CipherMode.DECRYPT);
final byte[] cipherText = createCipherText(cipher, Base64.getDecoder().decode(encryptedString));
return new String(cipherText, StandardCharsets.UTF_8);
public String decrypt(final String encryptedString) {
byte[] decodedBytes = Base64.getDecoder().decode(encryptedString);
return new String(processWithCipher(decodedBytes, decryptCipher), StandardCharsets.UTF_8);
}

private static Cipher createCipherFromSecret(final String secret, final CipherMode mode) {
Expand All @@ -41,15 +47,15 @@ private static Cipher createCipherFromSecret(final String secret, final CipherMo
cipher.init(mode.opmode, originalKey);
return cipher;
} catch (final Exception e) {
throw new RuntimeException("Error creating cipher from secret in mode " + mode.name());
throw new RuntimeException("Error creating cipher from secret in mode " + mode.name(), e);
}
}

private static byte[] createCipherText(final Cipher cipher, final byte[] data) {
private static byte[] processWithCipher(final byte[] data, final Cipher cipher) {
try {
return cipher.doFinal(data);
} catch (final Exception e) {
throw new RuntimeException("The cipher was unable to perform pass over data");
throw new RuntimeException("Error processing data with cipher", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class JwtVendor {
private final JoseJwtProducer jwtProducer;
private final LongSupplier timeProvider;
private final Boolean bwcModeEnabled;
private final EncryptionDecryptionUtil encryptionDecryptionUtil;

public JwtVendor(final Settings settings, final Optional<LongSupplier> timeProvider) {
JoseJwtProducer jwtProducer = new JoseJwtProducer();
Expand All @@ -56,6 +57,7 @@ public JwtVendor(final Settings settings, final Optional<LongSupplier> timeProvi
throw new IllegalArgumentException("encryption_key cannot be null");
} else {
this.claimsEncryptionKey = settings.get("encryption_key");
this.encryptionDecryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey);
}
if (timeProvider.isPresent()) {
this.timeProvider = timeProvider.get();
Expand Down Expand Up @@ -140,7 +142,7 @@ public String createJwt(

if (roles != null) {
String listOfRoles = String.join(",", roles);
jwtClaims.setProperty("er", EncryptionDecryptionUtil.encrypt(claimsEncryptionKey, listOfRoles));
jwtClaims.setProperty("er", encryptionDecryptionUtil.encrypt(listOfRoles));
} else {
throw new Exception("Roles cannot be null");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,15 @@ public class OnBehalfOfAuthenticator implements HTTPAuthenticator {
private final Boolean oboEnabled;
private final String clusterName;

private final EncryptionDecryptionUtil encryptionUtil;

public OnBehalfOfAuthenticator(Settings settings, String clusterName) {
String oboEnabledSetting = settings.get("enabled");
oboEnabled = oboEnabledSetting == null ? Boolean.TRUE : Boolean.valueOf(oboEnabledSetting);
encryptionKey = settings.get("encryption_key");
jwtParser = initParser(settings.get("signing_key"));
this.clusterName = clusterName;
this.encryptionUtil = new EncryptionDecryptionUtil(encryptionKey);
}

private JwtParser initParser(final String signingKey) {
Expand All @@ -84,7 +87,7 @@ private List<String> extractSecurityRolesFromClaims(Claims claims) {
String rolesClaim = "";

if (er != null) {
rolesClaim = EncryptionDecryptionUtil.decrypt(encryptionKey, er.toString());
rolesClaim = encryptionUtil.decrypt(er.toString());
} else if (dr != null) {
rolesClaim = dr.toString();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ public void testEncryptDecrypt() {
String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes());
String data = "Hello, OpenSearch!";

String encryptedString = EncryptionDecryptionUtil.encrypt(secret, data);
String decryptedString = EncryptionDecryptionUtil.decrypt(secret, encryptedString);
EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret);

String encryptedString = util.encrypt(data);
String decryptedString = util.decrypt(encryptedString);

Assert.assertEquals(data, decryptedString);
}
Expand All @@ -34,22 +36,22 @@ public void testDecryptingWithWrongKey() {
String secret2 = Base64.getEncoder().encodeToString("wrongKey1234567".getBytes());
String data = "Hello, OpenSearch!";

String encryptedString = EncryptionDecryptionUtil.encrypt(secret1, data);
EncryptionDecryptionUtil util1 = new EncryptionDecryptionUtil(secret1);
String encryptedString = util1.encrypt(data);

RuntimeException ex = Assert.assertThrows(RuntimeException.class, () -> EncryptionDecryptionUtil.decrypt(secret2, encryptedString));
EncryptionDecryptionUtil util2 = new EncryptionDecryptionUtil(secret2);
RuntimeException ex = Assert.assertThrows(RuntimeException.class, () -> util2.decrypt(encryptedString));

Assert.assertEquals("The cipher was unable to perform pass over data", ex.getMessage());
Assert.assertEquals("Error processing data with cipher", ex.getMessage());
}

@Test
public void testDecryptingCorruptedData() {
String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes());
String corruptedEncryptedString = "corruptedData";

RuntimeException ex = Assert.assertThrows(
RuntimeException.class,
() -> EncryptionDecryptionUtil.decrypt(secret, corruptedEncryptedString)
);
EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret);
RuntimeException ex = Assert.assertThrows(RuntimeException.class, () -> util.decrypt(corruptedEncryptedString));

Assert.assertEquals("Last unit does not have enough valid bits", ex.getMessage());
}
Expand All @@ -59,8 +61,9 @@ public void testEncryptDecryptEmptyString() {
String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes());
String data = "";

String encryptedString = EncryptionDecryptionUtil.encrypt(secret, data);
String decryptedString = EncryptionDecryptionUtil.decrypt(secret, encryptedString);
EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret);
String encryptedString = util.encrypt(data);
String decryptedString = util.decrypt(encryptedString);

Assert.assertEquals(data, decryptedString);
}
Expand All @@ -70,14 +73,16 @@ public void testEncryptNullValue() {
String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes());
String data = null;

EncryptionDecryptionUtil.encrypt(secret, data);
EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret);
util.encrypt(data);
}

@Test(expected = NullPointerException.class)
public void testDecryptNullValue() {
String secret = Base64.getEncoder().encodeToString("mySecretKey12345".getBytes());
String data = null;

EncryptionDecryptionUtil.decrypt(secret, data);
EncryptionDecryptionUtil util = new EncryptionDecryptionUtil(secret);
util.decrypt(data);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ public void testCreateJwtWithRoles() throws Exception {
List<String> roles = List.of("IT", "HR");
List<String> backendRoles = List.of("Sales", "Support");
String expectedRoles = "IT,HR";
Integer expirySeconds = 300;
LongSupplier currentTime = () -> (int) 100;
int expirySeconds = 300;
LongSupplier currentTime = () -> (long) 100;
String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16);
Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build();
Long expectedExp = currentTime.getAsLong() + expirySeconds;
Expand All @@ -78,8 +78,8 @@ public void testCreateJwtWithRoles() throws Exception {
Assert.assertNotNull(jwt.getClaim("iat"));
Assert.assertNotNull(jwt.getClaim("exp"));
Assert.assertEquals(expectedExp, jwt.getClaim("exp"));
Assert.assertNotEquals(expectedRoles, jwt.getClaim("er"));
Assert.assertEquals(expectedRoles, EncryptionDecryptionUtil.decrypt(claimsEncryptionKey, jwt.getClaim("er").toString()));
EncryptionDecryptionUtil encryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey);
Assert.assertEquals(expectedRoles, encryptionUtil.decrypt(jwt.getClaim("er").toString()));
Assert.assertNull(jwt.getClaim("br"));
}

Expand All @@ -93,15 +93,13 @@ public void testCreateJwtWithBackwardsCompatibilityMode() throws Exception {
String expectedRoles = "IT,HR";
String expectedBackendRoles = "Sales,Support";

Integer expirySeconds = 300;
LongSupplier currentTime = () -> (int) 100;
int expirySeconds = 300;
LongSupplier currentTime = () -> (long) 100;
String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16);
Settings settings = Settings.builder()
.put("signing_key", "abc123")
.put("encryption_key", claimsEncryptionKey)
// CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings
.put(ConfigConstants.EXTENSIONS_BWC_PLUGIN_MODE, true)
// CS-ENFORCE-SINGLE
.build();
Long expectedExp = currentTime.getAsLong() + expirySeconds;

Expand All @@ -117,8 +115,8 @@ public void testCreateJwtWithBackwardsCompatibilityMode() throws Exception {
Assert.assertNotNull(jwt.getClaim("iat"));
Assert.assertNotNull(jwt.getClaim("exp"));
Assert.assertEquals(expectedExp, jwt.getClaim("exp"));
Assert.assertNotEquals(expectedRoles, jwt.getClaim("er"));
Assert.assertEquals(expectedRoles, EncryptionDecryptionUtil.decrypt(claimsEncryptionKey, jwt.getClaim("er").toString()));
EncryptionDecryptionUtil encryptionUtil = new EncryptionDecryptionUtil(claimsEncryptionKey);
Assert.assertEquals(expectedRoles, encryptionUtil.decrypt(jwt.getClaim("er").toString()));
Assert.assertNotNull(jwt.getClaim("br"));
Assert.assertEquals(expectedBackendRoles, jwt.getClaim("br"));
}
Expand Down Expand Up @@ -170,14 +168,14 @@ public void testCreateJwtWithBadRoles() {
String subject = "admin";
String audience = "audience_0";
List<String> roles = null;
Integer expirySecond = 300;
Integer expirySeconds = 300;
String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16);
Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build();
JwtVendor jwtVendor = new JwtVendor(settings, Optional.empty());

Throwable exception = Assert.assertThrows(RuntimeException.class, () -> {
try {
jwtVendor.createJwt(issuer, subject, audience, expirySecond, roles, List.of());
jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles, List.of());
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down

0 comments on commit 1ba378e

Please sign in to comment.