From aaf5bbf30d09e594304f52c11d4750f29dfd1f5b Mon Sep 17 00:00:00 2001 From: Pranit Shah <11985324+Pshah2023@users.noreply.github.com> Date: Thu, 28 Sep 2023 00:58:49 -0400 Subject: [PATCH] Improve pylint score, readme, and actions --- .github/workflows/{pylint.yml => py.yml} | 16 ++- README.md | 39 +++--- tuneease/getmusic/data/bigdata.py | 6 +- tuneease/getmusic/distributed/launch.py | 2 +- tuneease/getmusic/engine/solver.py | 12 +- .../modeling/roformer/diffusion_roformer.py | 8 +- tuneease/getmusic/utils/midi_config.py | 46 +++---- tuneease/getmusic/utils/misc.py | 4 +- tuneease/pathutility.py | 8 +- tuneease/pipeline/encoding.py | 114 +++++++----------- tuneease/pipeline/encoding_helpers.py | 32 ++--- tuneease/pipeline/file.py | 89 +------------- tuneease/pipeline/key_chord.py | 52 ++++---- tuneease/preprocess/binarize.py | 10 +- tuneease/preprocess/to_oct.py | 20 +-- tuneease/tests/memory.py | 11 ++ tuneease/tests/test_logger.py | 2 +- tuneease/tests/test_pathutility.py | 4 +- tuneease/tests/test_tuneease.py | 6 +- 19 files changed, 196 insertions(+), 285 deletions(-) rename .github/workflows/{pylint.yml => py.yml} (56%) create mode 100644 tuneease/tests/memory.py diff --git a/.github/workflows/pylint.yml b/.github/workflows/py.yml similarity index 56% rename from .github/workflows/pylint.yml rename to .github/workflows/py.yml index 2a45454..73ca894 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/py.yml @@ -1,4 +1,4 @@ -name: Pylint +name: Pylint, Pytest, and PyPI on: [push] @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -17,11 +17,23 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install -r requirements.txt pip install pylint - name: Analysing the code with pylint run: | pylint $(git ls-files '*.py') + continue-on-error: true - name: actions-pytest uses: xoviat/actions-pytest@0.1-alpha2 with: args: ./tuneease + continue-on-error: true + - name: Try pip installation + run: | + pip install -e . + continue-on-error: true + - name: Memory_Profiler + run: | + pip install memory_profiler + python -m memory_profiler ./tuneease/tests/memory.py + continue-on-error: true diff --git a/README.md b/README.md index d10af6b..ebb9f26 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Welcome to Tune Ease! This project provides intuitive website control and automa **Do not run this without a virtual environment!** 1. Create a virtual environment using Python 3: ```sh - python3 -m venv venv + python3 -m venv env ``` 2. Activate the virtual environment: - On macOS/Linux: `source env/bin/activate` @@ -72,31 +72,28 @@ Feel free to choose the setup option that best suits your needs. Enjoy the proje 1. Introduced the ability to use a CPU or a GPU. 2. Structural changes to the codebase. 1. Instead of using a yaml to instantiate the model, I used a new train.py with a variable that contains the data from train.yaml - 2. Created a folder named pipeline that contains code that was previously duplicated many times. (Quite painful...) - 3. Improved the Pylint score from 1.25 to 2.46 + 2. Created a folder named pipeline that contains code that was previously duplicated many times. + 3. Prevented duplicated runs of the same code with new classes + 4. Improved the Pylint score from 1.25 to 3.90 3. Changed hierarchy and propagated changes. -4. 
Packaged this into a code shareable repo. This mean the install time has been reduced from the initial 4-5 hours it took me in the beginning to an automatic install process offered here. +4. Packaged this into a code shareable repo. This means the install time has been reduced from the initial 4-5 hours it took me to the more automatic install process offered here. # Usage -If you cloned the repo: -```sh -python -m tuneease.tuneease # generates a file and prints the location to the file -python -m tuneease.server # starts the server for you to use the app through the port that is printed -``` -If you installed through pip: -```sh -tuneease-generate # generates something and prints the location of the generated file -tuneease # starts the server and prints the port to go to for the website +Attempts are made to find your MuseScore and checkpoint paths automatically. **Optionally, pass your MuseScore and checkpoint paths with the corresponding flags.** You do not need MuseScore most of the time, but you will need the checkpoint for the AI music. + +Use one of the commands below to generate a file. +```bash +# If you installed through pip +tuneease-generate --checkpoint_path +# If you cloned the repo +python -m tuneease.tuneease --checkpoint_path ``` -Attempts are made to find your MuseScore and checkpoint path automatically. **Optionally, include your MuseScore and checkpoint path with the flag** -```sh -# For just generating one item -tuneease-generate --museScore_path --checkpoint_path -python -m tuneease.tuneease --museScore_path --checkpoint_path -# For the server -tuneease --museScore_path --checkpoint_path -python -m tuneease.server --museScore_path --checkpoint_path + +To start the server (the website is mainly for using the AI): +```bash +tuneease --museScore_path --checkpoint_path +python -m tuneease.server --museScore_path --checkpoint_path ``` Normally, you can access the server at http://localhost:8080, or you can follow whatever port flask tells you to go to. diff --git a/tuneease/getmusic/data/bigdata.py b/tuneease/getmusic/data/bigdata.py index 488e3b2..c749c0d 100644 --- a/tuneease/getmusic/data/bigdata.py +++ b/tuneease/getmusic/data/bigdata.py @@ -1,9 +1,9 @@ -from torch.utils.data import Dataset -import numpy as np -import torch import random import itertools as it from ...getmusic.data.indexed_datasets import IndexedDataset +import numpy as np +import torch +from torch.utils.data import Dataset class BigDataset(Dataset): def __init__(self, prefix, vocab_size, path=None): diff --git a/tuneease/getmusic/distributed/launch.py b/tuneease/getmusic/distributed/launch.py index c25a23d..c0fa750 100644 --- a/tuneease/getmusic/distributed/launch.py +++ b/tuneease/getmusic/distributed/launch.py @@ -1,7 +1,7 @@ import torch from torch import distributed as dist from torch import multiprocessing as mp -import distributed as dist_fn +from . 
import distributed as dist_fn def find_free_port(): import socket diff --git a/tuneease/getmusic/engine/solver.py b/tuneease/getmusic/engine/solver.py index 46d98af..e23b770 100644 --- a/tuneease/getmusic/engine/solver.py +++ b/tuneease/getmusic/engine/solver.py @@ -49,13 +49,13 @@ def __init__(self, gpu, config, args, model, dataloader, logger, is_sample=False self.logger.debug('Load dictionary: {} tokens.'.format(len(self.t_h.ids_to_tokens))) - beat_note_factor =mc.beat_note_factor - max_notes_per_bar = mc.max_notes_per_bar - pos_resolution = mc.pos_resolution - bar_max = mc.bar_max + beat_note_factor =mc.BEAT_NOTE_FACTOR + max_notes_per_bar = mc.MAX_NOTES_PER_BAR + pos_resolution = mc.POS_RESOLUTION + bar_max = mc.BAR_MAX self.pos_in_bar = beat_note_factor * max_notes_per_bar * pos_resolution - self.pad_index = mc.duration_max * mc.pos_resolution - 1 - self.figure_size = mc.bar_max * mc.beat_note_factor * mc.max_notes_per_bar * mc.pos_resolution + self.pad_index = mc.DURATION_MAX * mc.POS_RESOLUTION - 1 + self.figure_size = mc.BAR_MAX * mc.BEAT_NOTE_FACTOR * mc.MAX_NOTES_PER_BAR * mc.POS_RESOLUTION if 'clip_grad_norm' in config['solver']: self.clip_grad_norm = instantiate_from_config(config['solver']['clip_grad_norm']) diff --git a/tuneease/getmusic/modeling/roformer/diffusion_roformer.py b/tuneease/getmusic/modeling/roformer/diffusion_roformer.py index 9e78dfe..b539a8d 100644 --- a/tuneease/getmusic/modeling/roformer/diffusion_roformer.py +++ b/tuneease/getmusic/modeling/roformer/diffusion_roformer.py @@ -72,8 +72,8 @@ def __init__( self.num_classes = self.roformer.vocab_size + 1 # defined in vocabulary, add an additional mask self.cond_weight = self.roformer.cond_weight self.tracks = 14 - self.pad_index = mc.duration_max * mc.pos_resolution - 1 - self.figure_size = mc.bar_max * mc.beat_note_factor * mc.max_notes_per_bar * mc.pos_resolution + self.pad_index = mc.DURATION_MAX * mc.POS_RESOLUTION - 1 + self.figure_size = mc.BAR_MAX * mc.BEAT_NOTE_FACTOR * mc.MAX_NOTES_PER_BAR * mc.POS_RESOLUTION self.num_timesteps = diffusion_step self.parametrization = 'x0' self.auxiliary_loss_weight = auxiliary_loss_weight @@ -174,8 +174,8 @@ def log_sample_categorical_infer(self, logits, figure_size): # use gum if i % 2 == 1: # duration track[:, self.pad_index+1:-1, :] = -70 # only decode duration tokens else: # only decode pitch tokens in $i$-th track - start = mc.tracks_start[i // 2] - end = mc.tracks_end[i // 2] + start = mc.TRACKS_START[i // 2] + end = mc.TRACKS_END[i // 2] track[:,:self.pad_index, :] = -70 track[:,self.pad_index+1:start,:] = -70 track[:,end+1:-1,:] = -70 diff --git a/tuneease/getmusic/utils/midi_config.py b/tuneease/getmusic/utils/midi_config.py index f0dd2c0..444111d 100644 --- a/tuneease/getmusic/utils/midi_config.py +++ b/tuneease/getmusic/utils/midi_config.py @@ -1,23 +1,23 @@ -pos_resolution = 4 # 16 # per beat (quarter note) -bar_max = 32 -velocity_quant = 4 -tempo_quant = 12 # 2 ** (1 / 12) -min_tempo = 16 -max_tempo = 256 -duration_max = 4 # 2 ** 8 * beat -max_ts_denominator = 6 # x/1 x/2 x/4 ... x/64 -max_notes_per_bar = 1 # 1/64 ... 
128/64 # -beat_note_factor = 4 # In MIDI format a note is always 4 beats -deduplicate = True -filter_symbolic = False -filter_symbolic_ppl = 16 -trunc_pos = 2 ** 16 # approx 30 minutes (1024 measures) -sample_len_max = 1024 # window length max -sample_overlap_rate = 1.5 -ts_filter = True -pool_num = 200 -max_inst = 127 -max_pitch = 127 -max_velocity = 127 -tracks_start = [16, 144, 997, 5366, 6921, 10489] -tracks_end = [143, 996, 5365, 6920, 10488, 11858] +POS_RESOLUTION = 4 # 16 # per beat (quarter note) +BAR_MAX = 32 +VELOCITY_QUANT = 4 +TEMPO_QUANT = 12 # 2 ** (1 / 12) +MIN_TEMPO = 16 +MAX_TEMPO = 256 +DURATION_MAX = 4 # 2 ** 8 * beat +MAX_TS_DENOMINATOR = 6 # x/1 x/2 x/4 ... x/64 +MAX_NOTES_PER_BAR = 1 # 1/64 ... 128/64 # +BEAT_NOTE_FACTOR = 4 # In MIDI format a note is always 4 beats +DEDUPLICATE = True +FILTER_SYMBOLIC = False +FILTER_SYMBOLIC_PPL = 16 +TRUNC_POS = 2 ** 16 # approx 30 minutes (1024 measures) +SAMPLE_LEN_MAX = 1024 # window length max +SAMPLE_OVERLAP_RATE = 1.5 +TS_FILTER = True +POOL_NUM = 200 +MAX_INST = 127 +MAX_PITCH = 127 +MAX_VELOCITY = 127 +TRACKS_START = [16, 144, 997, 5366, 6921, 10489] +TRACKS_END = [143, 996, 5365, 6920, 10488, 11858] diff --git a/tuneease/getmusic/utils/misc.py b/tuneease/getmusic/utils/misc.py index b769922..13751da 100644 --- a/tuneease/getmusic/utils/misc.py +++ b/tuneease/getmusic/utils/misc.py @@ -1,9 +1,9 @@ import importlib import random -import numpy as np -import torch import warnings import os +import numpy as np +import torch def seed_everything(seed, cudnn_deterministic=False): """ diff --git a/tuneease/pathutility.py b/tuneease/pathutility.py index 36b18b1..e5a6120 100644 --- a/tuneease/pathutility.py +++ b/tuneease/pathutility.py @@ -163,18 +163,16 @@ def musescore_path(self, printsuccess = False): return filename self.logger.info("On Windows and macOS, install MuseScore through the website:") self.logger.info("[https://musescore.org/en](https://musescore.org/en).") - self.logger.info("Attempts are made to find it automatically. Tested for windows.") - self.logger.info("If necessary, use --museScore_path ") + self.logger.info("Attempts are made to find it automatically. 
If necessary, use --musescore_path ") return None def checkpoint_path(self, printsuccess = False): path = os.path.join(self.project_directory(), "checkpoint.pth") if not os.path.exists(path): - self.logger.info("You have to install the below with the following path:") + self.logger.info("Download the checkpoint below to the following path, or use --checkpoint_path ") fileurl = 'https://1drv.ms/u/s!ArHNvccy1VzPkWGKXZDQY5k-kDi4?e=fFxcEq' self.logger.info(fileurl) - toprint = fileurl + f"{path}" - self.logger.info(toprint) + self.logger.info(path) return None else: if printsuccess: diff --git a/tuneease/pipeline/encoding.py b/tuneease/pipeline/encoding.py index b59e66d..e9c8334 100644 --- a/tuneease/pipeline/encoding.py +++ b/tuneease/pipeline/encoding.py @@ -1,6 +1,6 @@ import miditoolkit import math -from ..getmusic.utils.midi_config import pos_resolution, beat_note_factor, trunc_pos, max_pitch, max_inst, filter_symbolic, filter_symbolic_ppl +from ..getmusic.utils.midi_config import POS_RESOLUTION, BEAT_NOTE_FACTOR, TRUNC_POS, MAX_PITCH, MAX_INST, FILTER_SYMBOLIC, FILTER_SYMBOLIC_PPL from .encoding_helpers import t2e, e2t, d2e, e2d, time_signature_reduce, b2e, e2b, v2e, e2v from .presets import RootKinds from .item import Item @@ -14,7 +14,7 @@ def MIDI_to_encoding(filename, midi_obj, with_chord = None, condition_inst = Non r_h = RootKinds() if filename == "track_generation.py": def time_to_pos(t): - return round(t * pos_resolution / midi_obj.ticks_per_beat) + return round(t * POS_RESOLUTION / midi_obj.ticks_per_beat) notes_start_pos = [time_to_pos(j.start) for i in midi_obj.instruments for j in i.notes] if len(notes_start_pos) == 0: @@ -43,9 +43,9 @@ def time_to_pos(t): bar = 0 measure_length = None for j in range(len(pos_to_info)): - ts = e2t(pos_to_info[j][1]) + timesignature = e2t(pos_to_info[j][1]) if cnt == 0: - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + measure_length = timesignature[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // timesignature[1] pos_to_info[j][0] = bar pos_to_info[j][2] = cnt cnt += 1 @@ -58,18 +58,16 @@ def time_to_pos(t): for inst in midi_obj.instruments: for note in inst.notes: - if time_to_pos(note.start) >= trunc_pos: + if time_to_pos(note.start) >= TRUNC_POS: continue - info = pos_to_info[time_to_pos(note.start)] duration = d2e(time_to_pos(note.end) - time_to_pos(note.start)) - encoding.append([info[0], info[2], max_inst + 1 if inst.is_drum else inst.program, note.pitch + max_pitch + + encoding.append([info[0], info[2], MAX_INST + 1 if inst.is_drum else inst.program, note.pitch + MAX_PITCH + 1 if inst.is_drum else note.pitch, duration, v2e(note.velocity), info[1], info[3]]) if len(encoding) == 0: return list() encoding.sort() encoding, is_major, pitch_shift = normalize_to_c_major(filename, encoding) - # extract chords if with_chord: max_pos = 0 @@ -78,9 +76,8 @@ def time_to_pos(t): if (0 < note[3] < 128) and (note[2] in [0,25,32,48,80]): if chord_from_single and (str(note[2]) not in condition_inst): continue - - ts = e2t(note[6]) - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + timesignature = e2t(note[6]) + measure_length = timesignature[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // timesignature[1] max_pos = max( max_pos, measure_length * note[0] + note[1] + e2d(note[4])) note_items.append(Item( @@ -101,32 +98,29 @@ def time_to_pos(t): key_chord_transition_loglik=k_c_d.key_chord_transition_loglik ) else: - chords = [] - + chords = [] bar_idx = 0 - for c in chords: - if c == 'N.C.': + for chord in chords: + if chord == 'N.C.': 
bar_idx+=1 continue - r, k = c.split(':') + r, k = chord.split(':') if k == '': k = 'null' elif k == '7': k = 'seven' encoding.append((bar_idx, 0, 129, r_h.root_dict[r], r_h.kind_dict[k], 0, t2e(time_signature_reduce(4, 4)), 0)) bar_idx += 1 - encoding.sort() return encoding, pitch_shift, tpc elif filename == "position_generation.py": def time_to_pos(t): - return round(t * pos_resolution / midi_obj.ticks_per_beat) + return round(t * POS_RESOLUTION / midi_obj.ticks_per_beat) notes_start_pos = [time_to_pos(j.start) for i in midi_obj.instruments for j in i.notes] if len(notes_start_pos) == 0: return list() max_pos = max(notes_start_pos) + 1 - pos_to_info = [[None for _ in range(4)] for _ in range( max_pos)] # (Measure, TimeSig, Pos, Tempo) tsc = midi_obj.time_signature_changes # [TimeSignature(numerator=4, denominator=4, time=0)] @@ -150,9 +144,9 @@ def time_to_pos(t): bar = 0 measure_length = None for j in range(len(pos_to_info)): # every position gets a slot here, whether or not it holds a note - ts = e2t(pos_to_info[j][1]) + timesignature = e2t(pos_to_info[j][1]) if cnt == 0: - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] # e.g. for a 3/4 time signature: a 4/4 bar has 16 positions, so a 3/4 bar has 12 + measure_length = timesignature[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // timesignature[1] # e.g. for a 3/4 time signature: a 4/4 bar has 16 positions, so a 3/4 bar has 12 pos_to_info[j][0] = bar pos_to_info[j][2] = cnt cnt += 1 @@ -162,31 +156,26 @@ def time_to_pos(t): cnt -= measure_length bar += 1 encoding = [] - for inst in midi_obj.instruments: for note in inst.notes: - if time_to_pos(note.start) >= trunc_pos: + if time_to_pos(note.start) >= TRUNC_POS: continue - info = pos_to_info[time_to_pos(note.start)] duration = d2e(time_to_pos(note.end) - time_to_pos(note.start)) - encoding.append([info[0], info[2], max_inst + 1 if inst.is_drum else inst.program, note.pitch + max_pitch + + encoding.append([info[0], info[2], MAX_INST + 1 if inst.is_drum else inst.program, note.pitch + MAX_PITCH + 1 if inst.is_drum else note.pitch, duration, v2e(note.velocity), info[1], info[3]]) if len(encoding) == 0: return list() - encoding.sort() encoding, is_major, pitch_shift = normalize_to_c_major(filename, encoding) - - # extract chords if with_chord: max_pos = 0 note_items = [] for note in encoding: if 0 < note[3] < 128: # and str(note[2]) in condition_inst: - ts = e2t(note[6]) - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + timesignature = e2t(note[6]) + measure_length = timesignature[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // timesignature[1] max_pos = max( max_pos, measure_length * note[0] + note[1] + e2d(note[4])) note_items.append(Item( @@ -208,40 +197,35 @@ def time_to_pos(t): ) else: chords = [] - bar_idx = 0 - for c in chords: - if c == 'N.C.': + for chord in chords: + if chord == 'N.C.': bar_idx+=1 continue - r, k = c.split(':') + r, k = chord.split(':') if k == '': k = 'null' elif k == '7': k = 'seven' encoding.append((bar_idx, 0, 129, r_h.root_dict[r], r_h.kind_dict[k], 0, t2e(time_signature_reduce(4, 4)), 0)) bar_idx += 1 - encoding.sort() - return encoding, pitch_shift, tpc elif filename == "to_oct.py": def time_to_pos(t): - return round(t * pos_resolution / midi_obj.ticks_per_beat) + return round(t * POS_RESOLUTION / midi_obj.ticks_per_beat) notes_start_pos = [time_to_pos(j.start) for i in midi_obj.instruments for j in i.notes] if len(notes_start_pos) == 0: return list() - max_pos = min(max(notes_start_pos) + 1, trunc_pos) + max_pos = min(max(notes_start_pos) + 1, TRUNC_POS) pos_to_info = [[None for _ in range(4)] for _ in range( max_pos)] # (Measure, TimeSig, Pos, 
Tempo) tsc = midi_obj.time_signature_changes # [TimeSignature(numerator=4, denominator=4, time=0)] tpc = midi_obj.tempo_changes # [TempoChange(tempo=120.0, time=0)] - # filter tempo and ts change if len(tsc) > 1 or len(tpc) > 1: return ['welcome use my code'] - for i in range(len(tsc)): for j in range(time_to_pos(tsc[i].time), time_to_pos(tsc[i + 1].time) if i < len(tsc) - 1 else max_pos): if j < len(pos_to_info): @@ -261,9 +245,9 @@ def time_to_pos(t): bar = 0 measure_length = None for j in range(len(pos_to_info)): - ts = e2t(pos_to_info[j][1]) + timesignature = e2t(pos_to_info[j][1]) if cnt == 0: - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + measure_length = timesignature[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // timesignature[1] pos_to_info[j][0] = bar pos_to_info[j][2] = cnt cnt += 1 @@ -273,39 +257,35 @@ def time_to_pos(t): cnt -= measure_length bar += 1 encoding = [] - start_distribution = [0] * pos_resolution - + start_distribution = [0] * POS_RESOLUTION for inst in midi_obj.instruments: for note in inst.notes: - if time_to_pos(note.start) >= trunc_pos: + if time_to_pos(note.start) >= TRUNC_POS: continue - start_distribution[time_to_pos(note.start) % pos_resolution] += 1 + start_distribution[time_to_pos(note.start) % POS_RESOLUTION] += 1 info = pos_to_info[time_to_pos(note.start)] duration = d2e(time_to_pos(note.end) - time_to_pos(note.start)) - encoding.append((info[0], info[2], max_inst + 1 if inst.is_drum else inst.program, note.pitch + max_pitch + + encoding.append((info[0], info[2], MAX_INST + 1 if inst.is_drum else inst.program, note.pitch + MAX_PITCH + 1 if inst.is_drum else note.pitch, duration, v2e(note.velocity), info[1], info[3])) if len(encoding) == 0: return list() - tot = sum(start_distribution) start_ppl = 2 ** sum((0 if x == 0 else -(x / tot) * math.log2((x / tot)) for x in start_distribution)) # filter unaligned music - if filter_symbolic: - assert start_ppl <= filter_symbolic_ppl, 'filtered out by the symbolic filter: ppl = {:.2f}'.format( + if FILTER_SYMBOLIC: + assert start_ppl <= FILTER_SYMBOLIC_PPL, 'filtered out by the symbolic filter: ppl = {:.2f}'.format( start_ppl) - # normalize encoding.sort() encoding, is_major = normalize_to_c_major(filename, encoding) - # extract chords max_pos = 0 note_items = [] for note in encoding: if 0 <= note[3] < 128: - ts = e2t(note[6]) - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + timesignature = e2t(note[6]) + measure_length = timesignature[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // timesignature[1] max_pos = max( max_pos, measure_length * note[0] + note[1] + e2d(note[4])) note_items.append(Item( @@ -324,22 +304,18 @@ def time_to_pos(t): key_chord_loglik=k_c_d.key_chord_loglik, key_chord_transition_loglik=k_c_d.key_chord_transition_loglik ) - bar_idx = 0 - for c in chords: - r, k = c.split(':') + for chord in chords: + r, k = chord.split(':') if k == '': k = 'null' elif k == '7': k = 'seven' encoding.append((bar_idx, 0, 129, r_h.root_dict[r], r_h.kind_dict[k], 0, t2e(time_signature_reduce(4, 4)), 0)) bar_idx += 1 - encoding.sort() return encoding - - def encoding_to_MIDI(filename, encoding, tpc = None, decode_chord=None): magenta = Magenta() r_h = RootKinds() @@ -365,7 +341,7 @@ def encoding_to_MIDI(filename, encoding, tpc = None, decode_chord=None): for i in range(len(bar_to_pos)): bar_to_pos[i] = cur_pos ts = e2t(bar_to_timesig[i]) - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + measure_length = ts[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // ts[1] 
cur_pos += measure_length pos_to_tempo = [list() for _ in range( cur_pos + max(map(lambda x: x[1], encoding)))] @@ -379,7 +355,7 @@ def encoding_to_MIDI(filename, encoding, tpc = None, decode_chord=None): midi_obj = miditoolkit.midi.parser.MidiFile() midi_obj.tempo_changes = tpc def get_tick(bar, pos): - return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // pos_resolution + return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // POS_RESOLUTION midi_obj.instruments = [miditoolkit.containers.Instrument(program=( 0 if i == 128 else i), is_drum=(i == 128), name=str(i)) for i in range(128 + 1)] for i in encoding: @@ -442,7 +418,7 @@ def get_tick(bar, pos): for i in range(len(bar_to_pos)): bar_to_pos[i] = cur_pos ts = e2t(bar_to_timesig[i]) - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + measure_length = ts[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // ts[1] cur_pos += measure_length pos_to_tempo = [list() for _ in range( cur_pos + max(map(lambda x: x[1], encoding)))] @@ -454,9 +430,8 @@ def get_tick(bar, pos): if pos_to_tempo[i] is None: pos_to_tempo[i] = b2e(120.0) if i == 0 else pos_to_tempo[i - 1] midi_obj = miditoolkit.midi.parser.MidiFile() - def get_tick(bar, pos): - return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // pos_resolution + return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // POS_RESOLUTION midi_obj.instruments = [miditoolkit.containers.Instrument(program=( 0 if i == 128 else i), is_drum=(i == 128), name=str(i)) for i in range(128 + 1)] for i in encoding: @@ -512,7 +487,7 @@ def get_tick(bar, pos): for i in range(len(bar_to_pos)): bar_to_pos[i] = cur_pos ts = e2t(bar_to_timesig[i]) - measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1] + measure_length = ts[0] * BEAT_NOTE_FACTOR * POS_RESOLUTION // ts[1] cur_pos += measure_length pos_to_tempo = [list() for _ in range( cur_pos + max(map(lambda x: x[1], encoding)))] @@ -523,19 +498,15 @@ def get_tick(bar, pos): for i in range(len(pos_to_tempo)): if pos_to_tempo[i] is None: pos_to_tempo[i] = b2e(120.0) if i == 0 else pos_to_tempo[i - 1] - midi_obj = miditoolkit.midi.parser.MidiFile() midi_obj.tempo_changes = tpc - def get_tick(bar, pos): - return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // pos_resolution + return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // POS_RESOLUTION midi_obj.instruments = [miditoolkit.containers.Instrument(program=( 0 if i == 128 else i), is_drum=(i == 128), name=str(i)) for i in range(128 + 1)] - for i in encoding: start = get_tick(i[0], i[1]) program = i[2] - if program == 129 and decode_chord: root_name = r_h.root_list[i[3]] kind_name = r_h.kind_list[i[4]] @@ -554,7 +525,6 @@ def get_tick(bar, pos): duration = 1 end = start + duration velocity = e2v(i[5]) - midi_obj.instruments[program].notes.append(miditoolkit.containers.Note( start=start, end=end, pitch=pitch, velocity=velocity)) midi_obj.instruments = [ diff --git a/tuneease/pipeline/encoding_helpers.py b/tuneease/pipeline/encoding_helpers.py index ebf0141..5e6505a 100644 --- a/tuneease/pipeline/encoding_helpers.py +++ b/tuneease/pipeline/encoding_helpers.py @@ -1,4 +1,4 @@ -from ..getmusic.utils.midi_config import max_ts_denominator, max_notes_per_bar, duration_max, pos_resolution, min_tempo, max_tempo, tempo_quant, velocity_quant +from ..getmusic.utils.midi_config import MAX_TS_DENOMINATOR, MAX_NOTES_PER_BAR, DURATION_MAX, POS_RESOLUTION, MIN_TEMPO, MAX_TEMPO, TEMPO_QUANT, VELOCITY_QUANT import math class TS: @@ -6,8 +6,8 @@ class TS: ts_list = list() def __init__(self): - 
for i in range(0, max_ts_denominator + 1): # 1 ~ 64 - for j in range(1, ((2 ** i) * max_notes_per_bar) + 1): + for i in range(0, MAX_TS_DENOMINATOR + 1): # 1 ~ 64 + for j in range(1, ((2 ** i) * MAX_NOTES_PER_BAR) + 1): self.ts_dict[(j, 2 ** i)] = len(self.ts_dict) self.ts_list.append((j, 2 ** i)) @@ -25,8 +25,8 @@ class Dur: dur_dec = list() def __init__(self) -> None: - for i in range(duration_max): - for j in range(pos_resolution): + for i in range(DURATION_MAX): + for j in range(POS_RESOLUTION): self.dur_dec.append(len(self.dur_enc)) for k in range(2 ** i): self.dur_enc.append(len(self.dur_dec) - 1) @@ -41,13 +41,13 @@ def e2d(x): def time_signature_reduce(numerator, denominator): # reduction (when denominator is too large) - global max_ts_denominator - global max_notes_per_bar - while denominator > 2 ** max_ts_denominator and denominator % 2 == 0 and numerator % 2 == 0: + global MAX_TS_DENOMINATOR + global MAX_NOTES_PER_BAR + while denominator > 2 ** MAX_TS_DENOMINATOR and denominator % 2 == 0 and numerator % 2 == 0: denominator //= 2 numerator //= 2 # decomposition (when length of a bar exceed max_notes_per_bar) - while numerator > max_notes_per_bar * denominator: + while numerator > MAX_NOTES_PER_BAR * denominator: for i in range(2, numerator + 1): if numerator % i == 0: numerator //= i @@ -55,17 +55,17 @@ def time_signature_reduce(numerator, denominator): return numerator, denominator def v2e(x): - return x // velocity_quant + return x // VELOCITY_QUANT def e2v(x): - return (x * velocity_quant) + (velocity_quant // 2) + return (x * VELOCITY_QUANT) + (VELOCITY_QUANT // 2) def b2e(x): - x = max(x, min_tempo) - x = min(x, max_tempo) - x = x / min_tempo - e = round(math.log2(x) * tempo_quant) + x = max(x, MIN_TEMPO) + x = min(x, MAX_TEMPO) + x = x / MIN_TEMPO + e = round(math.log2(x) * TEMPO_QUANT) return e def e2b(x): - return 2 ** (x / tempo_quant) * min_tempo + return 2 ** (x / TEMPO_QUANT) * MIN_TEMPO diff --git a/tuneease/pipeline/file.py b/tuneease/pipeline/file.py index 15ec6ad..4bcae3a 100644 --- a/tuneease/pipeline/file.py +++ b/tuneease/pipeline/file.py @@ -5,7 +5,7 @@ from .key_chord import TokenHelper, KeyChordDetails from .encoding_helpers import b2e, t2e, time_signature_reduce from .presets import inst_to_row, prog_to_abrv, RootKinds -from ..getmusic.utils.midi_config import bar_max, deduplicate, sample_len_max, sample_overlap_rate +from ..getmusic.utils.midi_config import BAR_MAX, DEDUPLICATE, SAMPLE_LEN_MAX, SAMPLE_OVERLAP_RATE from .file_helpers import get_midi_dict, timeout, get_hash, lock_set, lock_write, writer track_name = ['lead', 'bass', 'drum', 'guitar', 'piano', 'string'] @@ -31,60 +31,43 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con r_h = RootKinds() if filename == "track_generation.py": empty_tracks = ~conditional_tracks & ~content_tracks - conditional_tracks &= ~empty_tracks # emptied tracks can not be condition conditional_tracks = torch.tensor(conditional_tracks).float() conditional_tracks = conditional_tracks.view(7,1).repeat(1,2).reshape(14,1) empty_tracks = torch.tensor(empty_tracks).float() empty_tracks = empty_tracks.view(7,1).repeat(1,2).reshape(14,1) - midi_obj = miditoolkit.midi.parser.MidiFile(file_name) - if conditional_tracks[-1]: with_chord = True else: with_chord = False - - # try: encoding, pitch_shift, tpc = MIDI_to_encoding('track_generation.py', midi_obj, with_chord, condition_inst, chord_from_single) if len(encoding) == 0: print('ERROR(BLANK): ' + file_name + '\n', end='') return None, 0 
bar_index_offset = 0 - figure_size = encoding[-1][0] * k_c_d.pos_in_bar + encoding[-1][1] pad_length = 1 #(512 - figure_size % 512) - figure_size += pad_length conditional_bool = conditional_tracks.repeat(1,figure_size) - empty_pos = empty_tracks.repeat(1, figure_size).type(torch.bool) datum = t_h.pad_index * torch.ones(14, figure_size, dtype=float) oov = 0 inv = 0 - chord_list = [] - tempo = b2e(67) - lead_start = 0 - idx = 0 while idx != len(encoding) - 1: e = encoding[idx] - bar = e[0] pos = e[1] inst = e[2] pitch = e[3] - if inst == 80: tempo = e[7] assert tempo != 0, 'bad tempo' - - # assert e[6] == 6 - if e[2] == 129: row = inst_to_row[str(inst)] r = r_h.root_list[e[3]] @@ -93,9 +76,7 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con datum[2 * row + 1][k_c_d.pos_in_bar * bar + pos : k_c_d.pos_in_bar * (bar + 1) + pos] = t_h.tokens_to_ids[k] idx += 1 continue - chord_list = [str(e[3])] - for f_idx in range(idx + 1, len(encoding)): if (encoding[f_idx][0] == bar) and (encoding[f_idx][1] == pos) and (encoding[f_idx][2] == inst): if encoding[f_idx][3] != pitch: @@ -103,23 +84,16 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con pitch = encoding[f_idx][3] else: break - idx = max(idx + 1, f_idx) - - dur = e[4] if dur == 0: continue - if not (str(inst) in inst_to_row): continue - row = inst_to_row[str(inst)] dur = t_h.tokens_to_ids['T'+str(e[4])] # duration - chord_string = ' '.join(chord_list) token = prog_to_abrv[str(inst)] + chord_string - if token in t_h.tokens_to_ids: pitch = t_h.tokens_to_ids[token] assert (dur < t_h.pad_index) and (pitch > t_h.pad_index), 'pitch index is {} and dur index is {}'.format(pitch, dur) @@ -128,74 +102,44 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con inv += 1 else: oov += 1 - datum = torch.where(empty_pos, t_h.empty_index, datum) - datum = torch.where(((datum != t_h.empty_index).float() * (1 - conditional_bool)).type(torch.bool), t_h.empty_index + 1, datum) - - # datum = datum[:,:1280] - # conditional_bool = conditional_bool[:,:1280] - - # # if trunc: - # datum = datum[:,:512] conditional_bool = conditional_bool[:,:512] - not_empty_pos = (torch.tensor(np.array(datum)) != t_h.empty_index).float() - have_cond = True - for i in range(14): if with_chord and conditional_tracks[i] == 1 and ((datum[i] == t_h.pad_index).sum() + (datum[i] == t_h.empty_index).sum()) == min(512,figure_size): have_cond = False break - return datum.unsqueeze(0), torch.tensor(tempo), not_empty_pos, conditional_bool, pitch_shift, tpc, have_cond elif filename == "position_generation.py": midi_obj = miditoolkit.midi.parser.MidiFile(file_name) - encoding, pitch_shift, tpc = MIDI_to_encoding("position_generation.py", midi_obj) - if len(encoding) == 0: print('ERROR(BLANK): ' + file_name + '\n', end='') return None, 0 - bar_index_offset = 0 - figure_size = max(encoding[-1][0] * k_c_d.pos_in_bar + encoding[-1][1], 512) - pad_length = 1 #(512 - figure_size % 512) - figure_size += pad_length - datum = t_h.pad_index * torch.ones(14, figure_size, dtype=float) - oov = 0 inv = 0 - chord_list = [] - tempo = b2e(67) - lead_start = 0 - idx = 0 - track_set = set() - while idx != len(encoding) - 1: e = encoding[idx] - bar = e[0] pos = e[1] inst = e[2] pitch = e[3] - if inst == 80: tempo = e[7] assert tempo != 0, 'bad tempo' - # assert e[6] == 6 - if e[2] == 129: row = inst_to_row[str(inst)] r = r_h.root_list[e[3]] @@ -204,9 +148,7 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = 
None, con datum[2 * row + 1][k_c_d.pos_in_bar * bar + pos : k_c_d.pos_in_bar * (bar + 1) + pos] = t_h.tokens_to_ids[k] idx += 1 continue - chord_list = [str(e[3])] - for f_idx in range(idx + 1, len(encoding)): if (encoding[f_idx][0] == bar) and (encoding[f_idx][1] == pos) and (encoding[f_idx][2] == inst): if encoding[f_idx][3] != pitch: @@ -214,25 +156,17 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con pitch = encoding[f_idx][3] else: break - idx = max(idx + 1, f_idx) - - dur = e[4] if dur == 0: continue - if not (str(inst) in inst_to_row): continue - row = inst_to_row[str(inst)] dur = t_h.tokens_to_ids['T'+str(e[4])] # duration - chord_string = ' '.join(chord_list) token = prog_to_abrv[str(inst)] + chord_string - track_set.add(track_name[prog_to_abrv[str(inst)]]) - if token in t_h.tokens_to_ids: pitch = t_h.tokens_to_ids[token] assert (dur < t_h.pad_index) and (pitch > t_h.pad_index), 'pitch index is {} and dur index is {}'.format(pitch, dur) @@ -241,26 +175,20 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con inv += 1 else: oov += 1 - datum[:,-pad_length:] = t_h.empty_index - print('The music has {} tracks, with {} positions'.format(track_set, datum.size()[1])) print('Representation Visualization:') print('\t0,1,2,3,4,5,6,7,8,...\n(0)lead\n(1)bass\n(2)drum\n(3)guitar\n(4)piano\n(5)string\n(6)chord') print('Example: condition on 100 to 200 position of lead, 300 to 400 position of piano, write command like this:\'0,100,200;4,300,400') condition_str = input('Input positions you want to condition on:') empty_str = input('Input positions you want to empty:') - empty_pos = torch.zeros_like(datum) condition_pos = torch.zeros_like(datum) empty_pos = create_pos_from_str(empty_str, empty_pos) condition_pos = create_pos_from_str(condition_str, condition_pos) - datum = torch.where(empty_pos.type(torch.bool), t_h.empty_index, datum) datum = torch.where(((datum != t_h.empty_index).float() * (1 - condition_pos)).type(torch.bool), t_h.empty_index + 1, datum) - not_empty_pos = (torch.tensor(np.array(datum)) != t_h.empty_index).float() - return datum.unsqueeze(0), torch.tensor(tempo), not_empty_pos, condition_pos, pitch_shift, tpc elif filename == "to_oct.py": try: @@ -279,31 +207,26 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con if midi_notes_count == 0: print('ERROR(BLANK): ' + file_name + '\n', end='') return None - no_empty_tracks = {'80':0,'32':0,'128':0,'25':0,'0':0,'48':0} for inst in midi_obj.instruments: no_empty_tracks[str(inst.program)] = 1 - if no_empty_tracks['80'] == 0 or sum(no_empty_tracks.values()) <= 1: print('ERROR(BAD TRACKS): ' + file_name + '\n', end='') return False try: e = MIDI_to_encoding("to_oct.py", midi_obj) - if len(e) == 0: print('ERROR(BLANK): ' + file_name + '\n', end='') return None - if len(e) == 1: print('ERROR(TEMPO CHANGE): ' + file_name + '\n', end='') return False - # if ts_filter: allowed_ts = t2e(time_signature_reduce(4, 4)) if not all(i[6] == allowed_ts for i in e): print('ERROR(TSFILT): ' + file_name + '\n', end='') return None - if deduplicate: + if DEDUPLICATE: duplicated = False dup_file_name = '' midi_hash = '0' * 32 @@ -324,10 +247,10 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con file_name + ' == ' + dup_file_name + '\n', end='') return None output_str_list = [] - sample_step = max(round(sample_len_max / sample_overlap_rate), 1) + sample_step = max(round(SAMPLE_LEN_MAX / SAMPLE_OVERLAP_RATE), 1) for p in range(0, 
len(e), sample_step): L = p - R = min(p + sample_len_max, len(e)) - 1 + R = min(p + SAMPLE_LEN_MAX, len(e)) - 1 bar_index_list = [e[i][0] for i in range(L, R + 1) if e[i][0] is not None] bar_index_min = 0 @@ -335,19 +258,17 @@ def F(filename, file_name, conditional_tracks = None, content_tracks = None, con if len(bar_index_list) > 0: bar_index_min = min(bar_index_list) bar_index_max = max(bar_index_list) - # to make bar index start from 0 bar_index_offset = -bar_index_min e_segment = [] for i in e[L: R + 1]: - if i[0] is None or i[0] + bar_index_offset < bar_max: + if i[0] is None or i[0] + bar_index_offset < BAR_MAX: e_segment.append(i) else: break tokens_per_note = 8 output_words = ([('<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) if k is not None else '') for i in e_segment for j, k in enumerate(i)]) # tokens_per_note - 1 for append_eos functionality of binarizer in fairseq output_str_list.append(' '.join(output_words)) - # no empty if not all(len(i.split()) > tokens_per_note * 2 - 1 for i in output_str_list): print('ERROR(ENCODE): ' + file_name + '\n', end='') diff --git a/tuneease/pipeline/key_chord.py b/tuneease/pipeline/key_chord.py index bb13362..0fdbbc5 100644 --- a/tuneease/pipeline/key_chord.py +++ b/tuneease/pipeline/key_chord.py @@ -11,16 +11,17 @@ class TokenHelper: empty_index = None def __init__(self) -> None: - config = Config().config - with open(config['solver']['vocab_path'],'r') as f: - tokens = f.readlines() + if not bool(len(self.tokens_to_ids)): + config = Config().config + with open(config['solver']['vocab_path'],'r') as f: + tokens = f.readlines() - for id, token in enumerate(tokens): - token, freq = token.strip().split('\t') - self.tokens_to_ids[token] = id - self.ids_to_tokens = list(self.tokens_to_ids.keys()) - self.pad_index = self.tokens_to_ids[''] - self.empty_index = len(self.ids_to_tokens) + for id, token in enumerate(tokens): + token, freq = token.strip().split('\t') + self.tokens_to_ids[token] = id + self.ids_to_tokens = list(self.tokens_to_ids.keys()) + self.pad_index = self.tokens_to_ids[''] + self.empty_index = len(self.ids_to_tokens) class KeyChordDetails: pos_in_bar = None @@ -39,20 +40,21 @@ class KeyChordDetails: figure_size = None def __init__(self) -> None: - magenta = Magenta() - self.pos_in_bar = self.beat_note_factor * self.max_notes_per_bar * self.pos_resolution - self.figure_size = self.bar_max * self.pos_in_bar - self.key_profile_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "getmusic", "utils", "key_profile.pickle") - self.key_profile = pickle.load(open(self.key_profile_path, 'rb')) + if not bool(self.pos_in_bar): + magenta = Magenta() + self.pos_in_bar = self.beat_note_factor * self.max_notes_per_bar * self.pos_resolution + self.figure_size = self.bar_max * self.pos_in_bar + self.key_profile_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "getmusic", "utils", "key_profile.pickle") + self.key_profile = pickle.load(open(self.key_profile_path, 'rb')) - self.chord_pitch_out_of_key_prob = 0.01 - self.key_change_prob = 0.001 - self.chord_change_prob = 0.5 - self.key_chord_distribution = magenta._key_chord_distribution( - chord_pitch_out_of_key_prob=self.chord_pitch_out_of_key_prob) - self.key_chord_loglik = np.log(self.key_chord_distribution) - self.key_chord_transition_distribution = magenta._key_chord_transition_distribution( - self.key_chord_distribution, - key_change_prob=self.key_change_prob, - chord_change_prob=self.chord_change_prob) - 
self.key_chord_transition_loglik = np.log(self.key_chord_transition_distribution) + self.chord_pitch_out_of_key_prob = 0.01 + self.key_change_prob = 0.001 + self.chord_change_prob = 0.5 + self.key_chord_distribution = magenta._key_chord_distribution( + chord_pitch_out_of_key_prob=self.chord_pitch_out_of_key_prob) + self.key_chord_loglik = np.log(self.key_chord_distribution) + self.key_chord_transition_distribution = magenta._key_chord_transition_distribution( + self.key_chord_distribution, + key_change_prob=self.key_change_prob, + chord_change_prob=self.chord_change_prob) + self.key_chord_transition_loglik = np.log(self.key_chord_transition_distribution) diff --git a/tuneease/preprocess/binarize.py b/tuneease/preprocess/binarize.py index ac9b9e8..fb79905 100644 --- a/tuneease/preprocess/binarize.py +++ b/tuneease/preprocess/binarize.py @@ -1,13 +1,13 @@ +from ..getmusic.data.indexed_datasets import IndexedDatasetBuilder +from ..pipeline.presets import prog_to_abrv, inst_to_row, RootKinds +from ..pipeline.key_chord import TokenHelper, KeyChordDetails + import multiprocessing as mp -import random from tqdm import tqdm import numpy as np import sys import os sys.path.append('/'.join(os.path.abspath(__file__).split('/')[:-2])) -from ..getmusic.data.indexed_datasets import IndexedDatasetBuilder -from ..pipeline.presets import prog_to_abrv, inst_to_row, RootKinds -from ..pipeline.key_chord import TokenHelper, KeyChordDetails def oct_to_rep(line): @@ -40,7 +40,7 @@ def oct_to_rep(line): chord_list = [] - datum = t_h.t_h.pad_index * np.ones([14, 1 + k_c_d.figure_size],dtype=float) + datum = t_h.pad_index * np.ones([14, 1 + k_c_d.figure_size],dtype=float) idx = 0 while idx != len(encoding) - 1: diff --git a/tuneease/preprocess/to_oct.py b/tuneease/preprocess/to_oct.py index 200468e..568e41f 100644 --- a/tuneease/preprocess/to_oct.py +++ b/tuneease/preprocess/to_oct.py @@ -5,7 +5,7 @@ from multiprocessing import Pool import os import sys -from ..getmusic.utils.midi_config import bar_max, max_notes_per_bar, beat_note_factor, pos_resolution, max_inst, max_pitch, duration_max, max_tempo, max_velocity, sample_len_max, pool_num +from ..getmusic.utils.midi_config import BAR_MAX, MAX_NOTES_PER_BAR, BEAT_NOTE_FACTOR, POS_RESOLUTION, MAX_INST, MAX_PITCH, DURATION_MAX, MAX_TEMPO, MAX_VELOCITY, SAMPLE_LEN_MAX, POOL_NUM from ..pipeline.encoding_helpers import TS, v2e, b2e from ..pipeline.file import F sys.path.append('/'.join(os.path.abspath(__file__).split('/')[:-2])) @@ -19,23 +19,23 @@ def gen_dictionary(file_name): ts = TS() num = 0 with open(file_name, 'w') as f: - for j in range(bar_max): + for j in range(BAR_MAX): print('<0-{}>'.format(j), num, file=f) - for j in range(beat_note_factor * max_notes_per_bar * pos_resolution): + for j in range(BEAT_NOTE_FACTOR * MAX_NOTES_PER_BAR * POS_RESOLUTION): print('<1-{}>'.format(j), num, file=f) - for j in range(max_inst + 1 + 1): + for j in range(MAX_INST + 1 + 1): # max_inst + 1 for percussion print('<2-{}>'.format(j), num, file=f) - for j in range(2 * max_pitch + 1 + 1): + for j in range(2 * MAX_PITCH + 1 + 1): # max_pitch + 1 ~ 2 * max_pitch + 1 for percussion print('<3-{}>'.format(j), num, file=f) - for j in range(duration_max * pos_resolution): + for j in range(DURATION_MAX * POS_RESOLUTION): print('<4-{}>'.format(j), num, file=f) - for j in range(v2e(max_velocity) + 1): + for j in range(v2e(MAX_VELOCITY) + 1): print('<5-{}>'.format(j), num, file=f) for j in range(len(ts.ts_list)): print('<6-{}>'.format(j), num, file=f) - for j in range(b2e(max_tempo) + 1): + 
for j in range(b2e(MAX_TEMPO) + 1): print('<7-{}>'.format(j), num, file=f) def G(file_name): try: @@ -58,7 +58,7 @@ def encoding_to_str(e): p = 0 tokens_per_note = 8 return ' '.join((['<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) for i in e[p: p +\ - sample_len_max] if i[0] + bar_index_offset < bar_max for j, k in enumerate(i)])) # 8 - 1 for append_eos functionality of binarizer in fairseq + SAMPLE_LEN_MAX] if i[0] + bar_index_offset < BAR_MAX for j, k in enumerate(i)])) # 8 - 1 for append_eos functionality of binarizer in fairseq if __name__ == '__main__': data_path = sys.argv[1] @@ -80,7 +80,7 @@ def encoding_to_str(e): total_file_cnt = len(file_list) file_list_split = file_list output_file = '{}/oct.txt'.format(prefix) - with Pool(pool_num) as p: + with Pool(POOL_NUM) as p: result = list(p.imap_unordered(G, file_list_split)) all_cnt += sum((1 if i is not None else 0 for i in result)) ok_cnt += sum((1 if i is True else 0 for i in result)) diff --git a/tuneease/tests/memory.py b/tuneease/tests/memory.py new file mode 100644 index 0000000..c2efe68 --- /dev/null +++ b/tuneease/tests/memory.py @@ -0,0 +1,11 @@ +# pip install memory_profiler +# Usage: python -m memory_profiler ./tuneease/tests/memory.py --checkpoint_path +from memory_profiler import profile +from tuneease.tuneease import tuneease + +@profile +def main(): +    tuneease() + +if __name__ == "__main__": +    main() \ No newline at end of file diff --git a/tuneease/tests/test_logger.py b/tuneease/tests/test_logger.py index 1006196..b9a8561 100644 --- a/tuneease/tests/test_logger.py +++ b/tuneease/tests/test_logger.py @@ -11,4 +11,4 @@ def test_ServerLogger_get(): server_logger = ServerLogger('test.log') logger = server_logger.get() assert logger is server_logger.logger - assert isinstance(logger, logging.Logger) \ No newline at end of file + assert isinstance(logger, logging.Logger) diff --git a/tuneease/tests/test_pathutility.py b/tuneease/tests/test_pathutility.py index ebc700a..e82dd8c 100644 --- a/tuneease/tests/test_pathutility.py +++ b/tuneease/tests/test_pathutility.py @@ -13,5 +13,5 @@ def test_path_utility_project_directory(): def test_path_utility_museScore_path(): path_util = PathUtility() - museScore_path = path_util.musescore_path() - assert os.path.exists(museScore_path) \ No newline at end of file + musescore_path = path_util.musescore_path() + assert os.path.exists(musescore_path) diff --git a/tuneease/tests/test_tuneease.py b/tuneease/tests/test_tuneease.py index d3846be..36f6721 100644 --- a/tuneease/tests/test_tuneease.py +++ b/tuneease/tests/test_tuneease.py @@ -26,7 +26,7 @@ def get_converter(): def test_tuneease_init(get_tuneease): tuneease = get_tuneease assert isinstance(tuneease, TuneEase) - + def test_convert(get_tuneease, get_music_file, get_converter): converter = get_converter music_file = get_music_file @@ -47,7 +47,7 @@ def test_number(get_tuneease, get_music_file): tuner = get_tuneease output_filepath = tuner.number(music_file) assert os.path.exists(output_filepath) - + def test_random_score(get_tuneease): tuneease = get_tuneease random_score_file = tuneease.random_score() @@ -60,4 +60,4 @@ def test_generate(get_tuneease): test_param = { "content-bass": "True", } output_filepath = tuneease.generate(test_param) - assert os.path.exists(output_filepath) + assert os.path.exists(output_filepath)