Skip to content

Commit

Permalink
Merge pull request #621 from 10up/enhancement/embeddings
Browse files Browse the repository at this point in the history
Enhancement/577 Threshold values for Embeddings
  • Loading branch information
dkotter committed Dec 7, 2023
2 parents 1e4d2a0 + b888fb4 commit a45e54a
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 10 deletions.
20 changes: 17 additions & 3 deletions includes/Classifai/Providers/OpenAI/EmbeddingCalculations.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ class EmbeddingCalculations {
*
* @param array $source_embedding Embedding data of the source item.
* @param array $compare_embedding Embedding data of the item to compare.
* @param float $threshold The threshold to use for the similarity calculation.
*
* @return bool|float
*/
public function similarity( array $source_embedding = [], array $compare_embedding = [] ) {
public function similarity( array $source_embedding = [], array $compare_embedding = [], $threshold = 1 ) {
if ( empty( $source_embedding ) || empty( $compare_embedding ) ) {
return false;
}
Expand Down Expand Up @@ -56,8 +58,20 @@ function( $x ) {
// Do the math.
$distance = 1.0 - ( $combined_average / sqrt( $source_average * $compare_average ) );

// Ensure we are within the range of 0 to 2.0.
return max( 0, min( abs( (float) $distance ), 2.0 ) );
/**
* Filter the threshold for the similarity calculation.
*
* @since 2.5.0
* @hook classifai_threshold
*
* @param {float} $threshold The threshold to use.
*
* @return {float} The threshold to use.
*/
$threshold = apply_filters( 'classifai_threshold', $threshold );

// Clamp the distance to the range 0 to $threshold (1.0 unless the caller or the filter changes it).
return max( 0, min( abs( (float) $distance ), $threshold ) );
}

}
71 changes: 69 additions & 2 deletions includes/Classifai/Providers/OpenAI/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,35 @@ public function register() {
add_action( 'wp_insert_post', [ $this, 'generate_embeddings_for_post' ] );
add_action( 'created_term', [ $this, 'generate_embeddings_for_term' ] );
add_action( 'edited_terms', [ $this, 'generate_embeddings_for_term' ] );
add_action( 'admin_enqueue_scripts', [ $this, 'enqueue_admin_assets' ] );
add_action( 'enqueue_block_editor_assets', [ $this, 'enqueue_editor_assets' ], 9 );
add_filter( 'rest_api_init', [ $this, 'add_process_content_meta_to_rest_api' ] );
add_action( 'add_meta_boxes', [ $this, 'add_metabox' ] );
add_action( 'save_post', [ $this, 'save_metabox' ] );
}
}

/**
 * Load the language-processing script and stylesheet.
 *
 * Both files are served from the plugin's `dist/` directory. Dependencies
 * and version strings come from get_asset_info() for the
 * `language-processing` bundle.
 *
 * @return void
 */
public function enqueue_admin_assets() {
	$script_url     = CLASSIFAI_PLUGIN_URL . 'dist/language-processing.js';
	$script_deps    = get_asset_info( 'language-processing', 'dependencies' );
	$script_version = get_asset_info( 'language-processing', 'version' );

	// The trailing `true` asks WordPress to print the script in the footer.
	wp_enqueue_script( 'classifai-language-processing-script', $script_url, $script_deps, $script_version, true );

	$style_url     = CLASSIFAI_PLUGIN_URL . 'dist/language-processing.css';
	$style_version = get_asset_info( 'language-processing', 'version' );

	// The stylesheet declares no dependencies and targets all media types.
	wp_enqueue_style( 'classifai-language-processing-style', $style_url, [], $style_version, 'all' );
}

/**
* Enqueue editor assets.
*/
Expand Down Expand Up @@ -268,6 +290,15 @@ public function sanitize_settings( $settings ) {
} else {
$new_settings['taxonomies'][ $taxonomy_key ] = '0';
}

// Sanitize the threshold setting.
$taxonomy_key = $taxonomy_key . '_threshold';
if ( isset( $settings['taxonomies'][ $taxonomy_key ] ) && '0' !== $settings['taxonomies'][ $taxonomy_key ] ) {
$threshold_value = min( absint( $settings['taxonomies'][ $taxonomy_key ] ), 100 );
$new_settings['taxonomies'][ $taxonomy_key ] = $threshold_value ? $threshold_value : 75;
} else {
$new_settings['taxonomies'][ $taxonomy_key ] = 75;
}
}

// Sanitize the number setting.
Expand Down Expand Up @@ -354,6 +385,39 @@ public function supported_post_types() {
return apply_filters( 'classifai_openai_embeddings_post_types', $this->get_supported_post_types() );
}

/**
 * Resolve the similarity threshold for a taxonomy.
 *
 * The setting is stored as a percentage under the key
 * `{taxonomy}_threshold` and is converted here to a decimal via
 * `1 - ( percentage / 100 )`. When the taxonomy has no stored value,
 * 75 is used. When no taxonomy is given, the percentage stays at 1,
 * so the value handed to the filter is 0.99.
 *
 * @since 2.5.0
 *
 * @param string $taxonomy Taxonomy slug.
 * @return float
 */
public function get_threshold( $taxonomy = '' ) {
	$settings   = $this->get_settings();
	$percentage = 1;

	if ( ! empty( $taxonomy ) ) {
		$setting_key = $taxonomy . '_threshold';
		$percentage  = $settings['taxonomies'][ $setting_key ] ?? 75;
	}

	// Turn the percentage into its decimal complement (e.g. 75 becomes 0.25).
	$decimal_threshold = 1 - ( (float) $percentage / 100 );

	/**
	 * Filter the threshold for the similarity calculation.
	 *
	 * @since 2.5.0
	 * @hook classifai_threshold
	 *
	 * @param {float}  $threshold The threshold to use.
	 * @param {string} $taxonomy  The taxonomy to get the threshold for.
	 *
	 * @return {float} The threshold to use.
	 */
	return apply_filters( 'classifai_threshold', $decimal_threshold, $taxonomy );
}

/**
* The list of supported post statuses.
*
Expand Down Expand Up @@ -468,6 +532,9 @@ private function set_terms( int $post_id = 0, array $embedding = [] ) {
continue;
}

// Get threshold setting for this taxonomy.
$threshold = $this->get_threshold( $tax );

// Get embedding similarity for each term.
foreach ( $terms as $term_id ) {
if ( ! current_user_can( 'assign_term', $term_id ) && ( ! defined( 'WP_CLI' ) || ! WP_CLI ) ) {
Expand All @@ -477,9 +544,9 @@ private function set_terms( int $post_id = 0, array $embedding = [] ) {
$term_embedding = get_term_meta( $term_id, 'classifai_openai_embeddings', true );

if ( $term_embedding ) {
$similarity = $calculations->similarity( $embedding, $term_embedding );
$similarity = $calculations->similarity( $embedding, $term_embedding, $threshold );
if ( false !== $similarity ) {
$embedding_similarity[ $tax ][ $term_id ] = $calculations->similarity( $embedding, $term_embedding );
$embedding_similarity[ $tax ][ $term_id ] = $similarity;
}
}
}
Expand Down
45 changes: 40 additions & 5 deletions includes/Classifai/Providers/Provider.php
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,14 @@ public function render_checkbox_group( array $args = array() ) {

// Iterate through all of our options.
foreach ( $args['options'] as $option_value => $option_label ) {
$value = '';
$default_key = array_search( $option_value, $args['default_values'], true );
$value = '';
$default_key = array_search( $option_value, $args['default_values'], true );
$option_value_theshold = $option_value . '_threshold';

// Get saved value, if any.
if ( isset( $setting_index[ $args['label_for'] ] ) ) {
$value = $setting_index[ $args['label_for'] ][ $option_value ] ?? '';
$value = $setting_index[ $args['label_for'] ][ $option_value ] ?? '';
$threshold_value = $setting_index[ $args['label_for'] ][ $option_value_theshold ] ?? '';
}

// Check for backward compatibility.
Expand All @@ -430,11 +432,16 @@ public function render_checkbox_group( array $args = array() ) {
</label>
</p>',
esc_attr( $this->option_name ),
esc_attr( $args['label_for'] ),
esc_attr( $args['label_for'] ?? '' ),
esc_attr( $option_value ),
checked( $value, $option_value, false ),
esc_html( $option_label )
);

// Render Threshold field.
if ( 'openai_embeddings' === $this->option_name && 'taxonomies' === $args['label_for'] ) {
$this->render_threshold_field( $args, $option_value_theshold, $threshold_value );
}
}

// Render description, if any.
Expand All @@ -446,6 +453,34 @@ public function render_checkbox_group( array $args = array() ) {
}
}

/**
 * Print the numeric "Threshold (%)" input for a single option.
 *
 * The field's id and name are built from the provider's option name,
 * the settings field's `label_for` argument and the option key. When no
 * value is supplied, the input is pre-filled with 75.
 *
 * @since 2.5.0
 *
 * @param array  $args         The args passed to add_settings_field().
 * @param string $option_value The option key used in the field's id and name.
 * @param string $value        The current value, if any.
 *
 * @return void
 */
public function render_threshold_field( $args, $option_value, $value ) {
	$markup = '<p class="threshold_wrapper">
<label for="%1$s_%2$s_%3$s">%4$s</label>
<br>
<input type="number" id="%1$s_%2$s_%3$s" class="small-text" name="classifai_%1$s[%2$s][%3$s]" value="%5$s" />
</p>';

	$option_name = esc_attr( $this->option_name );
	$label_for   = esc_attr( $args['label_for'] ?? '' );
	$option_key  = esc_attr( $option_value );
	$label_text  = esc_html__( 'Threshold (%)', 'classifai' );

	// Any falsy value (empty string, 0, null) falls back to the default of 75.
	$field_value = 75;
	if ( $value ) {
		$field_value = esc_attr( $value );
	}

	printf( $markup, $option_name, $label_for, $option_key, $label_text, $field_value );
}



/**
* Renders the checkbox group for 'Generate descriptive text' setting.
*
Expand Down Expand Up @@ -490,7 +525,7 @@ public function render_auto_caption_fields( $args ) {
</label>
</p>',
esc_attr( $this->option_name ),
esc_attr( $args['label_for'] ),
esc_attr( $args['label_for'] ?? '' ),
esc_attr( $option_value ),
checked( $default_value, $option_value, false ),
esc_html( $option_label )
Expand Down
3 changes: 3 additions & 0 deletions phpcs.xml.dist
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
<exclude name="Squiz.Commenting.FileComment.MissingPackageTag" />
<!-- Class properties don't need this -->
<exclude name="Generic.Commenting.DocComment.MissingShort" />

<!-- For CI, also fail on warnings -->
<config name="ignore_warnings_on_exit" value="0"/>
</rule>

<exclude-pattern>*/tests/*</exclude-pattern>
Expand Down
17 changes: 17 additions & 0 deletions src/scss/language-processing.scss
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,20 @@
color: rgb(117, 117, 117);
}
}

// Spacing for the threshold wrapper paragraphs inside settings table cells.
.form-table td p.threshold_wrapper {
	margin: 10px 0 40px;
}

// The last wrapper in a cell gets a smaller bottom margin.
.form-table td p.threshold_wrapper:last-of-type {
	margin-bottom: 20px;
}

// Labels inside the wrapper are inline-block with a bottom margin.
.form-table td p.threshold_wrapper label {
	display: inline-block;
	margin: 0 0 10px;
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ describe( '[Language processing] Classify Content (OpenAI) Tests', () => {
cy.get( '#openai_embeddings_post_types_post' ).check();
cy.get( '#openai_embeddings_post_statuses_publish' ).check();
cy.get( '#openai_embeddings_taxonomies_category' ).check();
cy.get( '#openai_embeddings_taxonomies_category_threshold' ).type( 80 ); // "Test" requires 80% confidence. At 81%, it does not apply.
cy.get( '#number' ).clear().type( 1 );
cy.get( '#submit' ).click();
} );
Expand Down

0 comments on commit a45e54a

Please sign in to comment.