diff --git a/colabs/audiocraft/AudioCraft_MusicGen.ipynb b/colabs/audiocraft/AudioCraft_MusicGen.ipynb index 65186293..134d7c91 100644 --- a/colabs/audiocraft/AudioCraft_MusicGen.ipynb +++ b/colabs/audiocraft/AudioCraft_MusicGen.ipynb @@ -4,12 +4,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "EZU3hg4B1om6", - "outputId": "1512f321-698c-4d01-e8d9-6d9577eee7e1" + "id": "EZU3hg4B1om6" }, "outputs": [], "source": [ @@ -44,12 +39,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 122 - }, - "id": "3MTX8GoE7AzN", - "outputId": "33e92860-9ff9-4923-87a5-66c2d89e5f22" + "id": "3MTX8GoE7AzN" }, "outputs": [], "source": [ @@ -119,11 +109,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "SfM8rhVX7ES9", - "outputId": "22fd7b60-ce23-4aac-86cf-5d5e0a712815" + "id": "SfM8rhVX7ES9" }, "outputs": [], "source": [ @@ -157,29 +143,32 @@ "outputs": [], "source": [ "def get_spectrogram(audio_file, output_file):\n", - " # Read the audio file\n", " sample_rate, samples = wavfile.read(audio_file)\n", - "\n", - " # Compute the spectrogram\n", " frequencies, times, Sxx = signal.spectrogram(samples, sample_rate)\n", "\n", - " # Create a figure and axis for plotting\n", - " fig, ax = plt.subplots()\n", + " log_Sxx = 10 * np.log10(Sxx + 1e-10)\n", + " vmin = np.percentile(log_Sxx, 5)\n", + " vmax = np.percentile(log_Sxx, 95)\n", "\n", - " # Plot the spectrogram\n", - " ax.pcolormesh(times, frequencies, 10 * np.log10(Sxx), shading='gouraud')\n", + " mean_spectrum = np.mean(log_Sxx, axis=1)\n", + " threshold_low = np.percentile(mean_spectrum, 5)\n", + " threshold_high = np.percentile(mean_spectrum, 95)\n", "\n", - " # Remove axis, labels, and other decorations\n", - " ax.axis('off')\n", - " plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n", + " freq_indices = np.where(mean_spectrum > threshold_low)\n", + " freq_min = 20\n", + " freq_max = frequencies[freq_indices].max()\n", "\n", - " # Save the figure to a temporary buffer\n", - " plt.savefig(output_file, format='png', bbox_inches='tight', pad_inches=0)\n", + " fig, ax = plt.subplots()\n", + " cmap = plt.get_cmap(\"magma\")\n", + "\n", + " ax.pcolormesh(times, frequencies, log_Sxx, shading=\"gouraud\", cmap=cmap, vmin=vmin, vmax=vmax)\n", + " ax.axis(\"off\")\n", + " ax.set_ylim([freq_min, freq_max])\n", "\n", - " # Close the plot\n", + " plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n", + " plt.savefig(output_file, format=\"png\", bbox_inches=\"tight\", pad_inches=0)\n", " plt.close()\n", "\n", - " # Return the image saved in the buffer\n", " return wandb.Image(output_file)" ] }, @@ -187,12 +176,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 140 - }, - "id": "4fcOpiYx7Fqf", - "outputId": "4282158f-1170-4a3b-e514-a4f3916e3a7e" + "id": "4fcOpiYx7Fqf" }, "outputs": [], "source": [ @@ -253,4 +237,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file