Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix buffer leak and jni bytes object leak for jni command/write #28913

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 93 additions & 99 deletions src/controller/java/CHIPDeviceController-JNI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,36 @@ JNI_METHOD(void, read)
}
}

// Convert Json to Tlv, and remove the outer structure
CHIP_ERROR ConvertJsonToTlvWithoutStruct(const std::string & json, MutableByteSpan & data)
{
Platform::ScopedMemoryBufferWithSize<uint8_t> buf;
VerifyOrReturnError(buf.Calloc(data.size()), CHIP_ERROR_NO_MEMORY);
MutableByteSpan dataWithStruct(buf.Get(), buf.AllocatedSize());
ReturnErrorOnFailure(JsonToTlv(json, dataWithStruct));
TLV::TLVReader tlvReader;
TLV::TLVType outerContainer = TLV::kTLVType_Structure;
tlvReader.Init(dataWithStruct);
ReturnErrorOnFailure(tlvReader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag()));
ReturnErrorOnFailure(tlvReader.EnterContainer(outerContainer));
ReturnErrorOnFailure(tlvReader.Next());

TLV::TLVWriter tlvWrite;
tlvWrite.Init(data);
ReturnErrorOnFailure(tlvWrite.CopyElement(TLV::AnonymousTag(), tlvReader));
ReturnErrorOnFailure(tlvWrite.Finalize());
data.reduce_size(tlvWrite.GetLengthWritten());
return CHIP_NO_ERROR;
}

CHIP_ERROR PutPreencodedWriteAttribute(app::WriteClient & writeClient, app::ConcreteDataAttributePath & path, const ByteSpan & data)
{
TLV::TLVReader reader;
reader.Init(data);
ReturnErrorOnFailure(reader.Next());
return writeClient.PutPreencodedAttribute(path, reader);
}

JNI_METHOD(void, write)
(JNIEnv * env, jobject self, jlong handle, jlong callbackHandle, jlong devicePtr, jobject attributeList, jint timedRequestTimeoutMs,
jint imTimeoutMs)
Expand All @@ -1875,8 +1905,6 @@ JNI_METHOD(void, write)
auto callback = reinterpret_cast<WriteAttributesCallback *>(callbackHandle);
app::WriteClient * writeClient = nullptr;
uint16_t convertedTimedRequestTimeoutMs = static_cast<uint16_t>(timedRequestTimeoutMs);
bool hasValidTlv = false;
bool hasValidJson = false;

ChipLogDetail(Controller, "IM write() called");

Expand Down Expand Up @@ -1909,9 +1937,6 @@ JNI_METHOD(void, write)
jbyteArray tlvBytesObj = nullptr;
bool hasDataVersion = false;
Optional<DataVersion> dataVersion = Optional<DataVersion>();
uint8_t * tlvBytes = nullptr;
size_t length = 0;
TLV::TLVReader reader;

SuccessOrExit(err = JniReferences::GetInstance().GetListItem(attributeList, i, attributeItem));
SuccessOrExit(err = JniReferences::GetInstance().FindMethod(
Expand Down Expand Up @@ -1955,53 +1980,34 @@ JNI_METHOD(void, write)

tlvBytesObj = static_cast<jbyteArray>(env->CallObjectMethod(attributeItem, getTlvByteArrayMethod));
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
app::ConcreteDataAttributePath path(static_cast<EndpointId>(endpointId), static_cast<ClusterId>(clusterId),
static_cast<AttributeId>(attributeId), dataVersion);
if (tlvBytesObj != nullptr)
{
jbyte * tlvBytesObjBytes = env->GetByteArrayElements(tlvBytesObj, nullptr);
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
length = static_cast<size_t>(env->GetArrayLength(tlvBytesObj));
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
tlvBytes = reinterpret_cast<uint8_t *>(tlvBytesObjBytes);
hasValidTlv = true;
JniByteArray tlvByteArray(env, tlvBytesObj);
SuccessOrExit(err = PutPreencodedWriteAttribute(*writeClient, path, tlvByteArray.byteSpan()));
}
else
{
SuccessOrExit(err = JniReferences::GetInstance().FindMethod(env, attributeItem, "getJsonString", "()Ljava/lang/String;",
&getJsonStringMethod));
jstring jsonJniString = static_cast<jstring>(env->CallObjectMethod(attributeItem, getJsonStringMethod));
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
if (jsonJniString != nullptr)
{
JniUtfString jsonUtfJniString(env, jsonJniString);
uint8_t bufWithStruct[chip::app::kMaxSecureSduLengthBytes] = { 0 };
uint8_t buf[chip::app::kMaxSecureSduLengthBytes] = { 0 };
TLV::TLVReader tlvReader;
TLV::TLVWriter tlvWrite;
TLV::TLVType outerContainer = TLV::kTLVType_Structure;
MutableByteSpan dataWithStruct{ bufWithStruct };
MutableByteSpan data{ buf };
SuccessOrExit(err = JsonToTlv(std::string(jsonUtfJniString.c_str(), jsonUtfJniString.size()), dataWithStruct));
tlvReader.Init(dataWithStruct);
SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag()));
SuccessOrExit(err = tlvReader.EnterContainer(outerContainer));
SuccessOrExit(err = tlvReader.Next());
tlvWrite.Init(data);
SuccessOrExit(err = tlvWrite.CopyElement(TLV::AnonymousTag(), tlvReader));
SuccessOrExit(err = tlvWrite.Finalize());
tlvBytes = buf;
length = tlvWrite.GetLengthWritten();
hasValidJson = true;
}
VerifyOrExit(jsonJniString != nullptr, err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
JniUtfString jsonUtfJniString(env, jsonJniString);
std::string jsonString = std::string(jsonUtfJniString.c_str(), jsonUtfJniString.size());

// Context: Chunk write is supported in sdk, oversized list could be chunked in multiple message. When transforming
// JSON to TLV, we need know the actual size for tlv blob when handling JsonToTlv
// TODO: Implement memory auto-grow to get the actual size needed for tlv blob when transforming tlv to json.
// Workaround: Allocate memory using json string's size, which is large enough to hold the corresponding tlv blob
Platform::ScopedMemoryBufferWithSize<uint8_t> tlvBytes;
size_t length = jsonUtfJniString.size();
VerifyOrExit(tlvBytes.Calloc(length), err = CHIP_ERROR_NO_MEMORY);
MutableByteSpan data(tlvBytes.Get(), tlvBytes.AllocatedSize());
SuccessOrExit(err = ConvertJsonToTlvWithoutStruct(jsonString, data));
SuccessOrExit(err = PutPreencodedWriteAttribute(*writeClient, path, data));
}
VerifyOrExit(hasValidTlv || hasValidJson, err = CHIP_ERROR_INVALID_ARGUMENT);

reader.Init(tlvBytes, length);
reader.Next();
SuccessOrExit(
err = writeClient->PutPreencodedAttribute(
chip::app::ConcreteDataAttributePath(static_cast<EndpointId>(endpointId), static_cast<ClusterId>(clusterId),
static_cast<AttributeId>(attributeId), dataVersion),
reader));
}

err = writeClient->SendWriteRequest(device->GetSecureSession().Value(),
Expand Down Expand Up @@ -2030,34 +2036,40 @@ JNI_METHOD(void, write)
}
}

CHIP_ERROR PutPreencodedInvokeRequest(app::CommandSender & commandSender, app::CommandPathParams & path, const ByteSpan & data)
{
// PrepareCommand does nott create the struct container with kFields and copycontainer below sets the
// kFields container already
ReturnErrorOnFailure(commandSender.PrepareCommand(path, false /* aStartDataStruct */));
TLV::TLVWriter * writer = commandSender.GetCommandDataIBTLVWriter();
VerifyOrReturnError(writer != nullptr, CHIP_ERROR_INCORRECT_STATE);
TLV::TLVReader reader;
reader.Init(data);
ReturnErrorOnFailure(reader.Next());
return writer->CopyContainer(TLV::ContextTag(app::CommandDataIB::Tag::kFields), reader);
}

JNI_METHOD(void, invoke)
(JNIEnv * env, jobject self, jlong handle, jlong callbackHandle, jlong devicePtr, jobject invokeElement, jint timedRequestTimeoutMs,
jint imTimeoutMs)
{
chip::DeviceLayer::StackLock lock;
CHIP_ERROR err = CHIP_NO_ERROR;
auto callback = reinterpret_cast<InvokeCallback *>(callbackHandle);
app::CommandSender * commandSender = nullptr;
uint32_t endpointId = 0;
uint32_t clusterId = 0;
uint32_t commandId = 0;
jmethodID getEndpointIdMethod = nullptr;
jmethodID getClusterIdMethod = nullptr;
jmethodID getCommandIdMethod = nullptr;
jmethodID getTlvByteArrayMethod = nullptr;
jmethodID getJsonStringMethod = nullptr;
jobject endpointIdObj = nullptr;
jobject clusterIdObj = nullptr;
jobject commandIdObj = nullptr;
jbyteArray tlvBytesObj = nullptr;
TLV::TLVReader reader;
TLV::TLVWriter * writer = nullptr;
uint8_t * tlvBytes = nullptr;
size_t length = 0;
bool hasValidTlv = false;
bool hasValidJson = false;
CHIP_ERROR err = CHIP_NO_ERROR;
auto callback = reinterpret_cast<InvokeCallback *>(callbackHandle);
app::CommandSender * commandSender = nullptr;
uint32_t endpointId = 0;
uint32_t clusterId = 0;
uint32_t commandId = 0;
jmethodID getEndpointIdMethod = nullptr;
jmethodID getClusterIdMethod = nullptr;
jmethodID getCommandIdMethod = nullptr;
jmethodID getTlvByteArrayMethod = nullptr;
jmethodID getJsonStringMethod = nullptr;
jobject endpointIdObj = nullptr;
jobject clusterIdObj = nullptr;
jobject commandIdObj = nullptr;
jbyteArray tlvBytesObj = nullptr;
uint16_t convertedTimedRequestTimeoutMs = static_cast<uint16_t>(timedRequestTimeoutMs);

ChipLogDetail(Controller, "IM invoke() called");

DeviceProxy * device = reinterpret_cast<DeviceProxy *>(devicePtr);
Expand Down Expand Up @@ -2093,49 +2105,32 @@ JNI_METHOD(void, invoke)

tlvBytesObj = static_cast<jbyteArray>(env->CallObjectMethod(invokeElement, getTlvByteArrayMethod));
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
if (tlvBytesObj != nullptr)
{
jbyte * tlvBytesObjBytes = env->GetByteArrayElements(tlvBytesObj, nullptr);
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
length = static_cast<size_t>(env->GetArrayLength(tlvBytesObj));
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
tlvBytes = reinterpret_cast<uint8_t *>(tlvBytesObjBytes);
hasValidTlv = true;
}
else
{
SuccessOrExit(err = JniReferences::GetInstance().FindMethod(env, invokeElement, "getJsonString", "()Ljava/lang/String;",
&getJsonStringMethod));
jstring jsonJniString = static_cast<jstring>(env->CallObjectMethod(invokeElement, getJsonStringMethod));
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
if (jsonJniString != nullptr)
app::CommandPathParams path(static_cast<EndpointId>(endpointId), /* group id */ 0, static_cast<ClusterId>(clusterId),
yunhanw-google marked this conversation as resolved.
Show resolved Hide resolved
static_cast<CommandId>(commandId), app::CommandPathFlags::kEndpointIdValid);
if (tlvBytesObj != nullptr)
{
JniByteArray tlvBytesObjBytes(env, tlvBytesObj);
SuccessOrExit(err = PutPreencodedInvokeRequest(*commandSender, path, tlvBytesObjBytes.byteSpan()));
}
else
{
SuccessOrExit(err = JniReferences::GetInstance().FindMethod(env, invokeElement, "getJsonString", "()Ljava/lang/String;",
&getJsonStringMethod));
jstring jsonJniString = static_cast<jstring>(env->CallObjectMethod(invokeElement, getJsonStringMethod));
VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN);
VerifyOrExit(jsonJniString != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT);
JniUtfString jsonUtfJniString(env, jsonJniString);
uint8_t buf[chip::app::kMaxSecureSduLengthBytes] = { 0 };
MutableByteSpan tlvEncodingLocal{ buf };
// The invoke does not support chunk, kMaxSecureSduLengthBytes should be enough for command json blob
uint8_t tlvBytes[chip::app::kMaxSecureSduLengthBytes] = { 0 };
MutableByteSpan tlvEncodingLocal{ tlvBytes };
SuccessOrExit(err = JsonToTlv(std::string(jsonUtfJniString.c_str(), jsonUtfJniString.size()), tlvEncodingLocal));
tlvBytes = tlvEncodingLocal.data();
length = tlvEncodingLocal.size();
hasValidJson = true;
SuccessOrExit(err = PutPreencodedInvokeRequest(*commandSender, path, tlvEncodingLocal));
}
}
VerifyOrExit(hasValidTlv || hasValidJson, err = CHIP_ERROR_INVALID_ARGUMENT);

SuccessOrExit(err = commandSender->PrepareCommand(app::CommandPathParams(static_cast<EndpointId>(endpointId), /* group id */ 0,
static_cast<ClusterId>(clusterId),
static_cast<CommandId>(commandId),
app::CommandPathFlags::kEndpointIdValid),
false));

writer = commandSender->GetCommandDataIBTLVWriter();
VerifyOrExit(writer != nullptr, err = CHIP_ERROR_INCORRECT_STATE);
reader.Init(tlvBytes, static_cast<size_t>(length));
reader.Next();
SuccessOrExit(err = writer->CopyContainer(TLV::ContextTag(app::CommandDataIB::Tag::kFields), reader));
SuccessOrExit(err = commandSender->FinishCommand(convertedTimedRequestTimeoutMs != 0
? Optional<uint16_t>(convertedTimedRequestTimeoutMs)
: Optional<uint16_t>::Missing()));

SuccessOrExit(err =
commandSender->SendCommandRequest(device->GetSecureSession().Value(),
imTimeoutMs != 0 ? MakeOptional(System::Clock::Milliseconds32(imTimeoutMs))
Expand All @@ -2144,7 +2139,6 @@ JNI_METHOD(void, invoke)
callback->mCommandSender = commandSender;

exit:

if (err != CHIP_NO_ERROR)
{
ChipLogError(Controller, "JNI IM Invoke Error: %s", err.AsString());
Expand Down