Skip to content

Commit

Permalink
Merge pull request #622 from 10up/enhancement/591
Browse files Browse the repository at this point in the history
Enhancement/591 Previewing of OpenAI Embeddings
  • Loading branch information
dkotter authored Dec 12, 2023
2 parents c8e2fbd + faff524 commit 7e6f0b7
Show file tree
Hide file tree
Showing 7 changed files with 474 additions and 218 deletions.
10 changes: 8 additions & 2 deletions includes/Classifai/Admin/PreviewClassifierData.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public function __construct() {
public function get_post_classifier_preview_data() {
$nonce = isset( $_POST['nonce'] ) ? sanitize_text_field( wp_unslash( $_POST['nonce'] ) ) : false;

if ( ! $nonce || ! wp_verify_nonce( $nonce, 'classifai-previewer-action' ) ) {
if ( ! $nonce || ! wp_verify_nonce( $nonce, 'classifai-previewer-watson_nlu-action' ) ) {
wp_send_json_error( esc_html__( 'Failed nonce check.', 'classifai' ) );
}

Expand All @@ -45,7 +45,13 @@ public function get_post_classifier_preview_data() {
public function get_post_search_results() {
$nonce = isset( $_POST['nonce'] ) ? sanitize_text_field( wp_unslash( $_POST['nonce'] ) ) : false;

if ( ! $nonce || ! wp_verify_nonce( $nonce, 'classifai-previewer-action' ) ) {
if (
! $nonce
|| (
! wp_verify_nonce( $nonce, 'classifai-previewer-openai_embeddings-action' )
&& ! wp_verify_nonce( $nonce, 'classifai-previewer-watson_nlu-nonce' )
)
) {
wp_send_json_error( esc_html__( 'Failed nonce check.', 'classifai' ) );
}

Expand Down
19 changes: 3 additions & 16 deletions includes/Classifai/Providers/OpenAI/EmbeddingCalculations.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ class EmbeddingCalculations {
*
* @param array $source_embedding Embedding data of the source item.
* @param array $compare_embedding Embedding data of the item to compare.
* @param float $threshold The threshold to use for the similarity calculation.
*
* @return bool|float
*/
public function similarity( array $source_embedding = [], array $compare_embedding = [], $threshold = 1 ) {
public function similarity( array $source_embedding = [], array $compare_embedding = [] ) {
if ( empty( $source_embedding ) || empty( $compare_embedding ) ) {
return false;
}
Expand Down Expand Up @@ -58,20 +57,8 @@ function( $x ) {
// Do the math.
$distance = 1.0 - ( $combined_average / sqrt( $source_average * $compare_average ) );

/**
* Filter the threshold for the similarity calculation.
*
* @since 2.5.0
* @hook classifai_threshold
*
* @param {float} $threshold The threshold to use.
*
* @return {float} The threshold to use.
*/
$threshold = apply_filters( 'classifai_threshold', $threshold );

// Ensure we are within the range of 0 to 1.0 (i.e. $threshold).
return max( 0, min( abs( (float) $distance ), $threshold ) );
// Ensure we are within the range of 0 to 1.0.
return max( 0, min( abs( (float) $distance ), 1.0 ) );
}

}
163 changes: 138 additions & 25 deletions includes/Classifai/Providers/OpenAI/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public function register() {
add_filter( 'rest_api_init', [ $this, 'add_process_content_meta_to_rest_api' ] );
add_action( 'add_meta_boxes', [ $this, 'add_metabox' ] );
add_action( 'save_post', [ $this, 'save_metabox' ] );
add_action( 'wp_ajax_get_post_classifier_embeddings_preview_data', array( $this, 'get_post_classifier_embeddings_preview_data' ) );
}
}

Expand Down Expand Up @@ -456,12 +457,36 @@ public function supported_taxonomies() {
return apply_filters( 'classifai_openai_embeddings_taxonomies', $this->get_supported_taxonomies() );
}

/**
* Get the data to preview terms.
*
* @since 2.5.0
*
* @return array
*/
public function get_post_classifier_embeddings_preview_data() {
$nonce = isset( $_POST['nonce'] ) ? sanitize_text_field( wp_unslash( $_POST['nonce'] ) ) : false;

if ( ! $nonce || ! wp_verify_nonce( $nonce, 'classifai-previewer-openai_embeddings-action' ) ) {
wp_send_json_error( esc_html__( 'Failed nonce check.', 'classifai' ) );
}

$post_id = filter_input( INPUT_POST, 'post_id', FILTER_SANITIZE_NUMBER_INT );

$embeddings_terms = $this->generate_embeddings_for_post( $post_id, true );

return wp_send_json_success( $embeddings_terms );
}

/**
* Trigger embedding generation for content being saved.
*
* @param int $post_id ID of post being saved.
* @param int $post_id ID of post being saved.
* @param bool $dryrun Whether to run the process or just return the data.
*
* @return array|WP_Error
*/
public function generate_embeddings_for_post( $post_id ) {
public function generate_embeddings_for_post( $post_id, $dryrun = false ) {
// Don't run on autosaves.
if ( defined( 'DOING_AUTOSAVE' ) && DOING_AUTOSAVE ) {
return;
Expand All @@ -476,23 +501,30 @@ public function generate_embeddings_for_post( $post_id ) {

// Only run on supported post types and statuses.
if (
! in_array( $post->post_type, $this->supported_post_types(), true ) ||
! in_array( $post->post_status, $this->supported_post_statuses(), true )
! $dryrun
&& (
! in_array( $post->post_type, $this->supported_post_types(), true ) ||
! in_array( $post->post_status, $this->supported_post_statuses(), true )
)
) {
return;
}

// Don't run if turned off for this particular post.
if ( 'no' === get_post_meta( $post_id, '_classifai_process_content', true ) ) {
if ( 'no' === get_post_meta( $post_id, '_classifai_process_content', true ) && ! $dryrun ) {
return;
}

$embeddings = $this->generate_embeddings( $post_id, 'post' );

// Add terms to this item based on embedding data.
if ( $embeddings && ! is_wp_error( $embeddings ) ) {
update_post_meta( $post_id, 'classifai_openai_embeddings', array_map( 'sanitize_text_field', $embeddings ) );
$this->set_terms( $post_id, $embeddings );
if ( $dryrun ) {
return $this->get_terms( $embeddings );
} else {
update_post_meta( $post_id, 'classifai_openai_embeddings', array_map( 'sanitize_text_field', $embeddings ) );
return $this->set_terms( $post_id, $embeddings );
}
}
}

Expand All @@ -513,6 +545,102 @@ private function set_terms( int $post_id = 0, array $embedding = [] ) {

$settings = $this->get_settings();
$number_to_add = $settings['number'] ?? 1;
$embedding_similarity = $this->get_embeddings_similarity( $embedding );

if ( empty( $embedding_similarity ) ) {
return;
}

// Set terms based on similarity.
foreach ( $embedding_similarity as $tax => $terms ) {
// Sort embeddings from lowest to highest.
asort( $terms );

// Only add the number of terms specified in settings.
if ( count( $terms ) > $number_to_add ) {
$terms = array_slice( $terms, 0, $number_to_add, true );
}

wp_set_object_terms( $post_id, array_map( 'absint', array_keys( $terms ) ), $tax, false );
}
}

/**
* Get the terms of a post based on embeddings.
*
* @param array $embedding Embedding data.
*
* @return array|WP_Error
*/
private function get_terms( array $embedding = [] ) {
if ( empty( $embedding ) ) {
return new WP_Error( 'data_required', esc_html__( 'Valid embedding data is required to get terms.', 'classifai' ) );
}

$settings = $this->get_settings();
$number_to_add = $settings['number'] ?? 1;
$embedding_similarity = $this->get_embeddings_similarity( $embedding, false );

if ( empty( $embedding_similarity ) ) {
return;
}

// Set terms based on similarity.
$index = 0;
$result = [];

foreach ( $embedding_similarity as $tax => $terms ) {
// Get the taxonomy name.
$taxonomy = get_taxonomy( $tax );
$tax_name = $taxonomy->labels->singular_name;

// Sort embeddings from lowest to highest.
asort( $terms );

// Return the terms.
$result[ $index ] = new \stdClass();

$result[ $index ]->{$tax_name} = [];

$term_added = 0;
foreach ( $terms as $term_id => $similarity ) {
// Stop if we have added the number of terms specified in settings.
if ( $number_to_add <= $term_added ) {
break;
}

// Convert $similarity to percentage.
$similarity = round( ( 1 - $similarity ), 10 );

$result[ $index ]->{$tax_name}[] = [// phpcs:ignore Squiz.PHP.DisallowMultipleAssignments.Found
'label' => get_term( $term_id )->name,
'score' => $similarity,
];
$term_added++;
}

// Only add the number of terms specified in settings.
if ( count( $terms ) > $number_to_add ) {
$terms = array_slice( $terms, 0, $number_to_add, true );
}

$index++;
}

return $result;
}

/**
* Get the similarity between an embedding and all terms.
*
* @since 2.5.0
*
* @param array $embedding Embedding data.
* @param bool $consider_threshold Whether to consider the threshold setting.
*
* @return array
*/
private function get_embeddings_similarity( $embedding, $consider_threshold = true ) {
$embedding_similarity = [];
$taxonomies = $this->supported_taxonomies();
$calculations = new EmbeddingCalculations();
Expand Down Expand Up @@ -544,30 +672,15 @@ private function set_terms( int $post_id = 0, array $embedding = [] ) {
$term_embedding = get_term_meta( $term_id, 'classifai_openai_embeddings', true );

if ( $term_embedding ) {
$similarity = $calculations->similarity( $embedding, $term_embedding, $threshold );
if ( false !== $similarity ) {
$similarity = $calculations->similarity( $embedding, $term_embedding );
if ( false !== $similarity && ( ! $consider_threshold || $similarity <= $threshold ) ) {
$embedding_similarity[ $tax ][ $term_id ] = $similarity;
}
}
}
}

if ( empty( $embedding_similarity ) ) {
return;
}

// Set terms based on similarity.
foreach ( $embedding_similarity as $tax => $terms ) {
// Sort embeddings from lowest to highest.
asort( $terms );

// Only add the number of terms specified in settings.
if ( count( $terms ) > $number_to_add ) {
$terms = array_slice( $terms, 0, $number_to_add, true );
}

wp_set_object_terms( $post_id, array_map( 'absint', array_keys( $terms ) ), $tax, false );
}
return $embedding_similarity;
}

/**
Expand Down
35 changes: 25 additions & 10 deletions includes/Classifai/Services/Service.php
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,17 @@ public function render_settings_page() {
<?php
// Find the right provider class.
$provider = find_provider_class( $this->provider_classes ?? [], 'Natural Language Understanding' );

if ( ! is_wp_error( $provider ) && ! empty( $provider->can_register() ) && $provider->is_feature_enabled( 'content_classification' ) ) :
if ( 'openai_embeddings' === $active_tab ) {
$provider = find_provider_class( $this->provider_classes ?? [], 'Embeddings' );
}

if (
! is_wp_error( $provider )
&& ! empty(
$provider->can_register()
&& ( $provider->is_feature_enabled( 'content_classification' ) || $provider->is_feature_enabled( 'classification' ) )
)
) :
?>
<div id="classifai-post-preview-app">
<?php
Expand Down Expand Up @@ -191,26 +200,32 @@ public function render_settings_page() {
);
?>

<?php if ( 'watson_nlu' === $active_tab ) : ?>
<?php if ( 'watson_nlu' === $active_tab || 'openai_embeddings' === $active_tab ) : ?>
<h2><?php esc_html_e( 'Preview Language Processing', 'classifai' ); ?></h2>
<div id="classifai-post-preview-controls">
<select id="classifai-preview-post-selector">
<?php foreach ( $posts_to_preview as $post ) : ?>
<option value="<?php echo esc_attr( $post->ID ); ?>"><?php echo esc_html( $post->post_title ); ?></option>
<?php endforeach; ?>
</select>
<?php wp_nonce_field( 'classifai-previewer-action', 'classifai-previewer-nonce' ); ?>
<?php wp_nonce_field( "classifai-previewer-$active_tab-action", "classifai-previewer-$active_tab-nonce" ); ?>
<button type="button" class="button" id="get-classifier-preview-data-btn">
<span><?php esc_html_e( 'Preview', 'classifai' ); ?></span>
</button>
</div>
<div id="classifai-post-preview-wrapper">
<?php foreach ( $features as $feature_slug => $feature ) : ?>
<div class="tax-row tax-row--<?php echo esc_attr( $feature['plural'] ); ?> <?php echo esc_attr( $feature['enabled'] ) ? '' : 'tax-row--hide'; ?>">
<div class="tax-type"><?php echo esc_html( $feature['name'] ); ?></div>
</div>
<?php endforeach; ?>
</div>
<?php
if ( 'watson_nlu' === $active_tab ) :
foreach ( $features as $feature_slug => $feature ) :
?>
<div class="tax-row tax-row--<?php echo esc_attr( $feature['plural'] ); ?> <?php echo esc_attr( $feature['enabled'] ) ? '' : 'tax-row--hide'; ?>">
<div class="tax-type"><?php echo esc_html( $feature['name'] ); ?></div>
</div>
<?php
endforeach;
endif;
?>
</div>
<?php endif; ?>
</div>
<?php endif; ?>
Expand Down
Loading

0 comments on commit 7e6f0b7

Please sign in to comment.