Skip to content

Commit

Permalink
TKSS-1002: Check pointer value
Browse files Browse the repository at this point in the history
  • Loading branch information
johnshajiang committed Dec 26, 2024
1 parent 3578a37 commit 9a45a24
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ private static long createCtx() {

@Override
public void close() {
nativeCrypto().sm3hmacFreeMac(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm3hmacFreeMac(pointer);
super.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ abstract class NativeRef implements Closeable {
long pointer;

NativeRef(long pointer) {
if (pointer <= 0) {
if (pointer == 0) {
throw new IllegalStateException("Create context failed");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ public byte[] encrypt(byte[] plaintext) throws BadPaddingException {
throw new BadPaddingException("Invalid plaintext");
}

byte[] ciphertext = nativeCrypto().sm2CipherEncrypt(pointer, plaintext);
byte[] ciphertext = pointer == 0
? null
: nativeCrypto().sm2CipherEncrypt(pointer, plaintext);
if (ciphertext == null) {
throw new BadPaddingException("Encrypt failed");
}
Expand All @@ -67,7 +69,9 @@ public byte[] decrypt(byte[] ciphertext) throws BadPaddingException {
throw new BadPaddingException("Invalid ciphertext");
}

byte[] cleartext = nativeCrypto().sm2CipherDecrypt(pointer, ciphertext);
byte[] cleartext = pointer == 0
? null
: nativeCrypto().sm2CipherDecrypt(pointer, ciphertext);
if (cleartext == null) {
throw new BadPaddingException("Decrypt failed");
}
Expand All @@ -76,7 +80,9 @@ public byte[] decrypt(byte[] ciphertext) throws BadPaddingException {

@Override
public void close() {
nativeCrypto().sm2CipherFreeCtx(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm2CipherFreeCtx(pointer);
super.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public byte[] deriveKey(byte[] priKey, byte[] pubKey, byte[] ePriKey, byte[] id,
throw new IllegalStateException("Shared key length must be greater than 0");
}

byte[] sharedKey = nativeCrypto().sm2DeriveKey(pointer,
byte[] sharedKey = pointer == 0 ? null : nativeCrypto().sm2DeriveKey(pointer,
priKey, pubKey, ePriKey, id,
peerPubKey, peerEPubKey, peerId,
isInitiator, sharedKeyLength);
Expand All @@ -64,7 +64,9 @@ public byte[] deriveKey(byte[] priKey, byte[] pubKey, byte[] ePriKey, byte[] id,

@Override
public void close() {
nativeCrypto().sm2KeyExFreeCtx(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm2KeyExFreeCtx(pointer);
super.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ private static long createCtx() {
// K is the private key, 32-bytes
// X and Y are the coordinates of the public key, 32-bytes
public byte[] genKeyPair() {
byte[] keyPair = nativeCrypto().sm2KeyPairGenGenKeyPair(pointer);
byte[] keyPair = pointer == 0
? null
: nativeCrypto().sm2KeyPairGenGenKeyPair(pointer);
if (keyPair == null) {
throw new IllegalStateException("Generate key pair failed");
}
Expand All @@ -65,7 +67,9 @@ public byte[] genKeyPair() {

@Override
public void close() {
nativeCrypto().sm2KeyPairGenFreeCtx(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm2KeyPairGenFreeCtx(pointer);
super.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

import javax.crypto.BadPaddingException;

import static com.tencent.kona.crypto.provider.nativeImpl.NativeCrypto.OPENSSL_SUCCESS;
import static com.tencent.kona.crypto.provider.nativeImpl.NativeCrypto.nativeCrypto;
import static com.tencent.kona.crypto.provider.nativeImpl.NativeCrypto.*;
import static com.tencent.kona.crypto.util.Constants.*;

/**
Expand Down Expand Up @@ -81,7 +80,9 @@ public byte[] sign(byte[] message) throws BadPaddingException {
throw new BadPaddingException("Message cannot be null");
}

byte[] signature = nativeCrypto().sm2SignatureSign(pointer, message);
byte[] signature = pointer == 0
? null
: nativeCrypto().sm2SignatureSign(pointer, message);
if (signature == null) {
throw new BadPaddingException("Sign failed");
}
Expand All @@ -97,13 +98,17 @@ public boolean verify(byte[] message, byte[] signature) throws BadPaddingExcepti
throw new BadPaddingException("Invalid signature");
}

int verified = nativeCrypto().sm2SignatureVerify(pointer, message, signature);
int verified = pointer == 0
? OPENSSL_FAILURE
: nativeCrypto().sm2SignatureVerify(pointer, message, signature);
return verified == OPENSSL_SUCCESS;
}

@Override
public void close() {
nativeCrypto().sm2SignatureFreeCtx(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm2SignatureFreeCtx(pointer);
super.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ public NativeSM3(long pointer) {
public void update(byte[] data) {
Objects.requireNonNull(data);

if (nativeCrypto().sm3Update(pointer, data) != OPENSSL_SUCCESS) {
if (pointer == 0 || nativeCrypto().sm3Update(pointer, data) != OPENSSL_SUCCESS) {
throw new IllegalStateException("sm3 update operation failed");
}
}

public byte[] doFinal() {
byte[] result = nativeCrypto().sm3Final(pointer);
byte[] result = pointer == 0
? null
: nativeCrypto().sm3Final(pointer);
if (result == null) {
throw new IllegalStateException("sm3 final operation failed");
}
Expand All @@ -60,21 +62,27 @@ public byte[] doFinal(byte[] data) {

@Override
public void close() {
nativeCrypto().sm3FreeCtx(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm3FreeCtx(pointer);
super.close();
}
}

public void reset() {
if (nativeCrypto().sm3Reset(pointer) != OPENSSL_SUCCESS) {
if (pointer == 0 || nativeCrypto().sm3Reset(pointer) != OPENSSL_SUCCESS) {
throw new IllegalStateException("sm3 reset operation failed");
}
}

@Override
protected NativeSM3 clone() {
if (pointer == 0) {
throw new IllegalStateException("Cannot clone SM3 instance");
}

long clonePointer = nativeCrypto().sm3Clone(pointer);
if (clonePointer <= 0) {
throw new IllegalStateException("sm3 clone operation failed");
throw new IllegalStateException("SM3 clone operation failed");
}
return new NativeSM3(clonePointer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public NativeSM3HMac(byte[] key) {
}

private static long createCtx(long macPointer, byte[] key) {
if (macPointer <= 0) {
if (macPointer == 0) {
throw new IllegalArgumentException("macPointer is invalid");
}

Expand All @@ -61,13 +61,16 @@ public NativeSM3HMac(long pointer) {
public void update(byte[] data) {
Objects.requireNonNull(data);

if (nativeCrypto().sm3hmacUpdate(pointer, data) != OPENSSL_SUCCESS) {
if (pointer == 0
|| nativeCrypto().sm3hmacUpdate(pointer, data) != OPENSSL_SUCCESS) {
throw new IllegalStateException("SM3Hmac update operation failed");
}
}

public byte[] doFinal() {
byte[] result = nativeCrypto().sm3hmacFinal(pointer);
byte[] result = pointer == 0
? null
: nativeCrypto().sm3hmacFinal(pointer);
if (result == null) {
throw new IllegalStateException("SM3Hmac final operation failed");
}
Expand All @@ -81,20 +84,27 @@ public byte[] doFinal(byte[] data) {

@Override
public void close() {
nativeCrypto().sm3hmacFreeCtx(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm3hmacFreeCtx(pointer);
super.close();
}
}

public void reset() {
if (nativeCrypto().sm3hmacReset(pointer) != OPENSSL_SUCCESS) {
if (pointer == 0
|| nativeCrypto().sm3hmacReset(pointer) != OPENSSL_SUCCESS) {
throw new IllegalStateException("SM3Hmac reset operation failed");
}
}

@Override
protected NativeSM3HMac clone() {
if (pointer == 0) {
throw new IllegalStateException("Cannot clone closed SM3Hmac instance");
}

long clonePointer = nativeCrypto().sm3hmacClone(pointer);
if (clonePointer <= 0) {
if (clonePointer == 0) {
throw new IllegalStateException("SM3Hmac clone operation failed");
}
return new NativeSM3HMac(clonePointer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,19 @@ private static long createCtx(boolean encrypt, Mode mode, boolean padding,
byte[] update(byte[] data) {
Objects.requireNonNull(data);

byte[] result = nativeCrypto().sm4Update(pointer, data);
byte[] result = pointer == 0
? null
: nativeCrypto().sm4Update(pointer, data);
if (result == null) {
throw new IllegalStateException("SM4 update operation failed");
}
return result;
}

byte[] doFinal() {
byte[] result = nativeCrypto().sm4Final(pointer);
byte[] result = pointer == 0
? null
: nativeCrypto().sm4Final(pointer);
if (result == null) {
throw new IllegalStateException("SM4 final operation failed");
}
Expand All @@ -104,8 +108,10 @@ byte[] doFinal(byte[] data) {

@Override
public void close() {
nativeCrypto().sm4FreeCtx(pointer);
super.close();
if (pointer != 0) {
nativeCrypto().sm4FreeCtx(pointer);
super.close();
}
}

final static class SM4CBC extends NativeSM4 {
Expand Down Expand Up @@ -161,7 +167,11 @@ static class SM4GCM extends NativeSM4 {
void updateAAD(byte[] aad) {
Objects.requireNonNull(aad);

nativeCrypto().sm4GCMUpdateAAD(pointer, aad);
if (pointer != 0) {
nativeCrypto().sm4GCMUpdateAAD(pointer, aad);
} else {
throw new IllegalStateException("SM4 updateAAD operation failed");
}
}

@Override
Expand All @@ -176,7 +186,8 @@ byte[] doFinal() {

byte[] getTag() {
byte[] tag = new byte[SM4_GCM_TAG_LEN];
if (nativeCrypto().sm4GCMProcTag(pointer, tag) != OPENSSL_SUCCESS) {
if (pointer == 0
|| nativeCrypto().sm4GCMProcTag(pointer, tag) != OPENSSL_SUCCESS) {
throw new IllegalStateException("SM4GCM getTag operation failed");
}
return tag;
Expand All @@ -187,7 +198,8 @@ void setTag(byte[] tag) {
throw new IllegalArgumentException("Tag must be 16-bytes");
}

if (nativeCrypto().sm4GCMProcTag(pointer, tag) != OPENSSL_SUCCESS) {
if (pointer == 0
|| nativeCrypto().sm4GCMProcTag(pointer, tag) != OPENSSL_SUCCESS) {
throw new IllegalStateException("SM4GCM setTag operation failed");
}
}
Expand Down

0 comments on commit 9a45a24

Please sign in to comment.