Skip to content

Commit

Permalink
Use GetConfigValue() to get stable_diffusion_seed and stable_diffusio…
Browse files Browse the repository at this point in the history
…n_num_steps
  • Loading branch information
anhappdev committed Oct 24, 2024
1 parent 73ab660 commit eb9eb2a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
1 change: 1 addition & 0 deletions mobile_back_tflite/cpp/backend_tflite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ cc_library(
}),
deps = [
":tflite_settings",
"//flutter/cpp:utils",
"//flutter/cpp/c:headers",
"@org_tensorflow//tensorflow/core:tflite_portable_logging",
"@org_tensorflow//tensorflow/lite/c:c_api",
Expand Down
1 change: 1 addition & 0 deletions mobile_back_tflite/cpp/backend_tflite/neuron/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ cc_library(
local_defines = ["MTK_TFLITE_NEURON_BACKEND"],
deps = [
":tflite_settings",
"//flutter/cpp:utils",
"//flutter/cpp/c:headers",
"//mobile_back_tflite/cpp/backend_tflite:tflite_settings",
"@org_tensorflow//tensorflow/core:tflite_portable_logging",
Expand Down
22 changes: 14 additions & 8 deletions mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <valarray>

#include "flutter/cpp/c/backend_c.h"
#include "flutter/cpp/utils.h"
#include "stable_diffusion_invoker.h"
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/common.h"
Expand Down Expand Up @@ -58,21 +59,26 @@ mlperf_backend_ptr_t StableDiffusionPipeline::backend_create(

// Verify only one instance of the backend exists at any time
if (backendExists) {
LOG(ERROR) << "Backend already exists";
return nullptr;
}

SDBackendData* backend_data = new SDBackendData();
backendExists = true;

for (int i = 0; i < configs->count; ++i) {
if (strcmp(configs->keys[i], "stable_diffusion_seed") == 0) {
backend_data->seed = atoi(configs->values[i]);
}
if (strcmp(configs->keys[i], "stable_diffusion_num_steps") == 0) {
backend_data->num_steps = atoi(configs->values[i]);
}
// Read seed and num_steps value from SD task settings
backend_data->seed =
mlperf::mobile::GetConfigValue(configs, "stable_diffusion_seed", 0);
if (backend_data->seed == 0) {
LOG(ERROR) << "Cannot get stable_diffusion_seed";
return nullptr;
}
backend_data->num_steps =
mlperf::mobile::GetConfigValue(configs, "stable_diffusion_num_steps", 0);
if (backend_data->num_steps == 0) {
LOG(ERROR) << "Cannot get stable_diffusion_num_steps";
return nullptr;
}

// Load models from the provided directory path
std::string text_encoder_path =
std::string(model_path) + "/sd_text_encoder_dynamic.tflite";
Expand Down

0 comments on commit eb9eb2a

Please sign in to comment.