Skip to content

Commit

Permalink
add support for OpenAIEmbedding
Browse files Browse the repository at this point in the history
  • Loading branch information
Sidsector9 committed Aug 30, 2024
1 parent 8f1d26a commit 8ffd0f0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
6 changes: 5 additions & 1 deletion includes/Classifai/Providers/Azure/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,13 @@ public function get_post_classifier_embeddings_preview_data(): array {
// Add terms to this item based on embedding data.
if ( $embeddings && ! is_wp_error( $embeddings ) ) {
$embeddings_terms = $this->get_terms( $embeddings );

if ( is_wp_error( $embeddings_terms ) ) {
wp_send_json_error( $embeddings_terms->get_error_message() );
}
}

return wp_send_json_success( $embeddings_terms );
wp_send_json_success( $embeddings_terms );
}

/**
Expand Down
20 changes: 11 additions & 9 deletions includes/Classifai/Providers/OpenAI/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,13 @@ public function get_post_classifier_embeddings_preview_data(): array {
// Add terms to this item based on embedding data.
if ( $embeddings && ! is_wp_error( $embeddings ) ) {
$embeddings_terms = $this->get_terms( $embeddings );

if ( is_wp_error( $embeddings_terms ) ) {
wp_send_json_error( $embeddings_terms->get_error_message() );
}
}

return wp_send_json_success( $embeddings_terms );
wp_send_json_success( $embeddings_terms );
}

/**
Expand Down Expand Up @@ -683,31 +687,29 @@ function ( $a, $b ) {
}

// Prepare the results.
$index = 0;
$results = [];

foreach ( $sorted_results as $tax => $terms ) {
// Get the taxonomy name.
$taxonomy = get_taxonomy( $tax );
$tax_name = $taxonomy->labels->singular_name;

// Setup our taxonomy object.
$results[] = new \stdClass();

$results[ $index ]->{$tax_name} = [];
// Initialize the taxonomy bucket in results.
$results[ $tax ] = [
'label' => $tax_name,
'data' => []
];

foreach ( $terms as $term ) {
// Convert $similarity to percentage.
$similarity = round( ( 1 - $term['similarity'] ), 10 );

// Store the results.
$results[ $index ]->{$tax_name}[] = [ // phpcs:ignore Squiz.PHP.DisallowMultipleAssignments.Found
$results[ $tax ]['data'][] = [
'label' => get_term( $term['term_id'] )->name,
'score' => $similarity,
];
}

++$index;
}

return $results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,27 +338,37 @@ function PreviewerResults() {

return (
<div className='classifai__classification-previewer-search-result-container'>
{ 'azure_openai_embeddings' === activeProvider && <AzureOpenAIEmbeddingsResults postId={ selectedPostId } /> }
{ 'azure_openai_embeddings' === activeProvider || 'openai_embeddings' === activeProvider && <AzureOpenAIEmbeddingsResults postId={ selectedPostId } /> }
</div>
);
}

function AzureOpenAIEmbeddingsResults( { postId } ) {
const {
isPreviewUnderProcess,
setPreviewUnderProcess,
setIsPreviewerOpen,
} = useContext( PreviewerProviderContext );

const [ responseData, setResponseData ] = useState( [] );
const [ errorMessage, setErrorMessage ] = useState( '' );
const settings = useSelect( ( select ) => select( STORE_NAME ).getFeatureSettings() );

useEffect( () => {
// Reset previous results.
if ( isPreviewUnderProcess ) {
setResponseData( [] );
}
}, [ isPreviewUnderProcess ] );

useEffect( () => {
if ( ! postId ) {
return;
}

setPreviewUnderProcess( true );
setIsPreviewerOpen( true );
setErrorMessage( '' );

const formData = new FormData();

Expand All @@ -384,6 +394,8 @@ function AzureOpenAIEmbeddingsResults( { postId } ) {

if ( responseJSON.success ) {
setResponseData( responseJSON.data );
} else {
setErrorMessage( responseJSON.data );
}

setPreviewUnderProcess( false );
Expand Down Expand Up @@ -419,6 +431,14 @@ function AzureOpenAIEmbeddingsResults( { postId } ) {
)
} );

if ( errorMessage ) {
return (
<Notice status='error' isDismissible={ false } className='classifai__classification-previewer-result-notice'>
{ errorMessage }
</Notice>
);
}

return card.length ? (
<>
<Notice status='success' isDismissible={ false } className='classifai__classification-previewer-result-notice'>
Expand Down

0 comments on commit 8ffd0f0

Please sign in to comment.