Skip to content

Commit

Permalink
add: audiocraft spectrogram visualization enhancements (#459)
Browse files Browse the repository at this point in the history
* add: dynamic range and better cmap

* add: frequency axis scaling based on mean spectrum threshold

* chore: changed string quotes to "
  • Loading branch information
mratanusarkar authored Sep 6, 2023
1 parent 4e973fa commit 6661b0b
Showing 1 changed file with 22 additions and 38 deletions.
60 changes: 22 additions & 38 deletions colabs/audiocraft/AudioCraft_MusicGen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "EZU3hg4B1om6",
"outputId": "1512f321-698c-4d01-e8d9-6d9577eee7e1"
"id": "EZU3hg4B1om6"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -44,12 +39,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"id": "3MTX8GoE7AzN",
"outputId": "33e92860-9ff9-4923-87a5-66c2d89e5f22"
"id": "3MTX8GoE7AzN"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -119,11 +109,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SfM8rhVX7ES9",
"outputId": "22fd7b60-ce23-4aac-86cf-5d5e0a712815"
"id": "SfM8rhVX7ES9"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -157,42 +143,40 @@
"outputs": [],
"source": [
"def get_spectrogram(audio_file, output_file):\n",
" # Read the audio file\n",
" sample_rate, samples = wavfile.read(audio_file)\n",
"\n",
" # Compute the spectrogram\n",
" frequencies, times, Sxx = signal.spectrogram(samples, sample_rate)\n",
"\n",
" # Create a figure and axis for plotting\n",
" fig, ax = plt.subplots()\n",
" log_Sxx = 10 * np.log10(Sxx + 1e-10)\n",
" vmin = np.percentile(log_Sxx, 5)\n",
" vmax = np.percentile(log_Sxx, 95)\n",
"\n",
" # Plot the spectrogram\n",
" ax.pcolormesh(times, frequencies, 10 * np.log10(Sxx), shading='gouraud')\n",
" mean_spectrum = np.mean(log_Sxx, axis=1)\n",
" threshold_low = np.percentile(mean_spectrum, 5)\n",
" threshold_high = np.percentile(mean_spectrum, 95)\n",
"\n",
" # Remove axis, labels, and other decorations\n",
" ax.axis('off')\n",
" plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n",
" freq_indices = np.where(mean_spectrum > threshold_low)\n",
" freq_min = 20\n",
" freq_max = frequencies[freq_indices].max()\n",
"\n",
" # Save the figure to a temporary buffer\n",
" plt.savefig(output_file, format='png', bbox_inches='tight', pad_inches=0)\n",
" fig, ax = plt.subplots()\n",
" cmap = plt.get_cmap(\"magma\")\n",
"\n",
" ax.pcolormesh(times, frequencies, log_Sxx, shading=\"gouraud\", cmap=cmap, vmin=vmin, vmax=vmax)\n",
" ax.axis(\"off\")\n",
" ax.set_ylim([freq_min, freq_max])\n",
"\n",
" # Close the plot\n",
" plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n",
" plt.savefig(output_file, format=\"png\", bbox_inches=\"tight\", pad_inches=0)\n",
" plt.close()\n",
"\n",
" # Return the image saved in the buffer\n",
" return wandb.Image(output_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 140
},
"id": "4fcOpiYx7Fqf",
"outputId": "4282158f-1170-4a3b-e514-a4f3916e3a7e"
"id": "4fcOpiYx7Fqf"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -253,4 +237,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

0 comments on commit 6661b0b

Please sign in to comment.