diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0551aa2..995c22d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -35,10 +35,36 @@ jobs: name: dist path: dist/ - publish_wheel: + test_wheel: runs-on: ubuntu-latest needs: [build_wheel] steps: + - name: Checkout github repo + uses: actions/checkout@v4 + - name: Checkout submodules + run: git submodule update --init --recursive + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: 3.11 + architecture: 'x64' + - uses: actions/download-artifact@v3 + with: + name: dist + path: dist/ + - name: Test the wheel + shell: bash {0} + run: | + pip install dist/gaga_phsp-*-py3-none-any.whl + cd tests + mkdir pth + python test001_non_cond.py + python test002_cond.py + + publish_wheel: + runs-on: ubuntu-latest + needs: [test_wheel] + steps: - name: Checkout github repo uses: actions/checkout@v4 - name: Checkout submodules @@ -49,7 +75,7 @@ jobs: path: dist/ - name: Publish to PyPI if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/') - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@master with: user: __token__ password: ${{ secrets.PYPI }} diff --git a/.gitignore b/.gitignore index a622154..0590d90 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,9 @@ aa.png /tests/npy/*.npy /tests/*pth /tests/*pt +/tests/*png +/tests/*old* + /save/ @@ -38,6 +41,7 @@ aa.png /tests/a.pdf /tests/jz/ /tests/output/ +/tests/pth_oct_2022 *OLD* /gaga/gaga_helpers_pet_before_dw_change.py diff --git a/bin/gaga_garf_generate_img b/bin/gaga_garf_generate_img deleted file mode 100755 index 844bc35..0000000 --- a/bin/gaga_garf_generate_img +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import logging -import os -import time -import itk -import click -import gatetools as gt -from tqdm import tqdm -import gaga_phsp as gaga -import garf - -logger = logging.getLogger(__name__) - -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) - - -@click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('gan_pth_filename') -@click.argument('garf_pth_filename') -@click.option('--n', '-n', default=1e5, help='Number of samples to generate') -@click.option('--output', '-o', required=True, help='Output filename.mhd ') -@click.option('--radius', '-r', default=float(180), help='Radius in mm') -@click.option('--start_angle', default=float(0), help='Starting angle (in deg)') -@click.option('--stop_angle', default=float(1), help='Stop angle (in deg)') -@click.option('--step_angle', default=float(1), help='Step angle (in deg)') -@click.option('--scale', '-s', default=float(1), help='Scale the final image by s. ') -@click.option('--debug_projection', default=False, is_flag=True, - help='Debug: dump npy file with U,V,theta,phi,E on the plane') -@click.option('--sigma', default=False, is_flag=True, - help='Compute and dump sigma (uncertainty) image') -@gt.add_options(gt.common_options) -def gaga_garf_generate_img(gan_pth_filename, garf_pth_filename, - radius, n, scale, - start_angle, stop_angle, step_angle, - output, debug_projection, sigma, - **kwargs): - """ - \b - Simulation of a SPECT image: - - input particles are generated from a GAN (gaga-phsp) - - detector plane use ARF (garf) to create the image - - \b - : input GAN-PHSP PTH file (.pth) - : input GARF PTH file (.pth) - """ - - # logger - gt.logging_conf(**kwargs) - - # input number of events - n = int(n) - - # load gan pth - logger.info(f'Reading GAN-PHSP from {gan_pth_filename}') - gan_params, G, D, optim, dtypef = gaga.load(gan_pth_filename) - - # load garf pth - logger.info(f'Reading GARF from {gan_pth_filename}') - garf_nn, garf_model = garf.load_nn(garf_pth_filename, verbose=False) - - # initialisation - batch_size = 4e6 - t0 = time.time() - - # initialisation plane - print('FIXME : image_plane_size_mm') - image_plane_size_mm = [576, 446] - - # initialisation garf - garf_param = {} - garf_param['gpu_batch_size'] = int(4e5) # float(gpu_batch_size)) - garf_param['size'] = 128 - garf_param['spacing'] = 4.41806 - garf_param['length'] = 99 - garf_param['N_scale'] = scale - garf_param['N_dataset'] = n - size_mm = garf_param['size'] * garf_param['spacing'] / 2.0 - - # initialisation - gan_batch_size = 2e5 - - # loop over angles - angle = start_angle - bar_n = (stop_angle - start_angle) / step_angle * n - pbar = tqdm(total=bar_n) - a = 0 - p = {"gan_params": gan_params, - "G": G, - "batch_size": batch_size, - "gan_batch_size": gan_batch_size, - # "plane": plane, - "image_plane_size_mm": image_plane_size_mm, - "debug": False, #### FIXME - "garf_nn": garf_nn, - "garf_model": garf_model, - "pbar": pbar, - "n": n, - "garf_param": garf_param} - - while angle < stop_angle: - tqdm.write(f'Angle {angle} deg') - ev = 0 - - plane = gaga.init_plane(batch_size, angle=angle, radius=radius) - p["plane"] = plane - image, sq_image = gaga.gaga_garf_generate_image(p) - - # save image - b, extension = os.path.splitext(output) - out = f'{b}_{str(angle).zfill(5)}{extension}' - itk.imwrite(image, out) - - # save image - sout = f'{b}_{str(angle).zfill(5)}-Squared{extension}' - itk.imwrite(sq_image, sout) - - # compute sigma if needed - if sigma: - filenames = [] - sfilenames = [] - filenames.append(out) - sfilenames.append(sout) - sigout = f'{b}_{str(angle).zfill(5)}_sigma{extension}' - nevents = n - sigma_flag = True - threshold = 0 - print(filenames, sfilenames, nevents) - uncertainty, m, nb = gt.image_uncertainty_by_slice(filenames, sfilenames, nevents, sigma_flag, threshold) - itk.imwrite(uncertainty, sigout) - - ev += batch_size - # pbar.update(batch_size) - angle += step_angle - a = a + 1 - - pbar.close() - - -# -------------------------------------------------------------------------- -if __name__ == '__main__': - gaga_garf_generate_img() diff --git a/bin/gaga_gauss_cond_test b/bin/gaga_gauss_cond_test index 0bd472c..75de8ea 100755 --- a/bin/gaga_gauss_cond_test +++ b/bin/gaga_gauss_cond_test @@ -1,10 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from shutil import copyfile import click from matplotlib import pyplot as plt -from scipy.stats import gaussian_kde from gatetools import phsp import numpy as np import gaga_phsp as gaga diff --git a/bin/gaga_gauss_plot b/bin/gaga_gauss_plot index b1e49e9..528dc7a 100755 --- a/bin/gaga_gauss_plot +++ b/bin/gaga_gauss_plot @@ -8,17 +8,24 @@ from matplotlib import pyplot as plt import gaga_phsp as gaga from scipy.stats import gaussian_kde -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) @click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('phsp_filename') -@click.argument('pth_filename') -@click.option('--n', '-n', default=1e4, help='Number of samples to get from the phsp') -@click.option('--m', '-m', default=1e4, help='Number of samples to generate from the GAN') -@click.option('-x', default=float(1), help='Condition x') -@click.option('-y', default=float(1), help='Condition y') -@click.option('--epoch', '-e', default=-1, help='Load the G net at the given epoch (-1 for last stored epoch)') +@click.argument("phsp_filename") +@click.argument("pth_filename") +@click.option("--n", "-n", default=1e4, help="Number of samples to get from the phsp") +@click.option( + "--m", "-m", default=1e4, help="Number of samples to generate from the GAN" +) +@click.option("-x", default=float(1), help="Condition x") +@click.option("-y", default=float(1), help="Condition y") +@click.option( + "--epoch", + "-e", + default=-1, + help="Load the G net at the given epoch (-1 for last stored epoch)", +) def gaga_gauss_plot(phsp_filename, pth_filename, n, m, epoch, x, y): """ \b @@ -41,42 +48,44 @@ def gaga_gauss_plot(phsp_filename, pth_filename, n, m, epoch, x, y): # generate samples with condition cond = None - if len(params['cond_keys']) > 0: + if len(params["cond_keys"]) > 0: condx = np.ones(m) * x condy = np.ones(m) * y print(condx.shape, condy.shape) cond = np.column_stack((condx, condy)) print(cond.shape) - fake = gaga.generate_samples2(params, G, D, m, m, False, True, cond=cond) + fake = gaga.generate_samples3(params, G, m, cond=cond) + else: + fake = gaga.generate_samples_non_cond(params, G, m, m, False, True) # get 2D points x_ref = real[:, 0] y_ref = real[:, 1] x = fake[:, 0] y = fake[:, 1] - print('ref shape', x_ref.shape, y_ref.shape) - print('gan shape', x.shape, y.shape) + print("ref shape", x_ref.shape, y_ref.shape) + print("gan shape", x.shape, y.shape) - print('ref y min max', y_ref.min(), y_ref.max()) - print('ref x min max', x_ref.min(), x_ref.max()) + print("ref y min max", y_ref.min(), y_ref.max()) + print("ref x min max", x_ref.min(), x_ref.max()) - print('gan y min max', y.min(), y.max()) - print('gan x min max', x.min(), x.max()) + print("gan y min max", y.min(), y.max()) + print("gan x min max", x.min(), x.max()) # plot fig, ax = plt.subplots(1, 1, figsize=(20, 10)) a = ax - a.scatter(x_ref, y_ref, marker='.', s=0.1) - a.scatter(x, y, marker='.', s=0.1) - a.axis('equal') + a.scatter(x_ref, y_ref, marker=".", s=0.1) + a.scatter(x, y, marker=".", s=0.1) + a.axis("equal") plt.title(pth_filename) - f = f'cond.png' + f = f"cond.png" print(f) plt.savefig(f) # -------------------------------------------------------------------------- -if __name__ == '__main__': +if __name__ == "__main__": gaga_gauss_plot() diff --git a/bin/gaga_generate b/bin/gaga_generate index 787fea9..8d1fa57 100755 --- a/bin/gaga_generate +++ b/bin/gaga_generate @@ -7,20 +7,23 @@ import gatetools.phsp as phsp import os import time -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) @click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('pth_filename') -@click.option('--n', '-n', default='1e4', help='Number of samples to generate') -@click.option('--output', '-o', default='AUTO', help='If AUTO, use pth_filename.npy') -@click.option('--output_folder', '-f', default=None, help='Output folder') -@click.option('--toggle/--no-toggle', default=False, help='Convert XY to angle') -@click.option('--epoch', default=-1, help='Use G at this epoch') -@click.option('--radius', default=350, help='When convert angle, need the radius (in mm)') -@click.option('--cond_phsp', '-c', default=None, help='Conditional phsp') -def gaga_generate(pth_filename, n, output, output_folder, toggle, - radius, epoch, cond_phsp): +@click.argument("pth_filename") +@click.option("--n", "-n", default="1e4", help="Number of samples to generate") +@click.option("--output", "-o", default="AUTO", help="If AUTO, use pth_filename.npy") +@click.option("--output_folder", "-f", default=None, help="Output folder") +@click.option("--toggle/--no-toggle", default=False, help="Convert XY to angle") +@click.option("--epoch", default=-1, help="Use G at this epoch") +@click.option( + "--radius", default=350, help="When convert angle, need the radius (in mm)" +) +@click.option("--cond_phsp", "-c", default=None, help="Conditional phsp") +def gaga_generate( + pth_filename, n, output, output_folder, toggle, radius, epoch, cond_phsp +): """ Generate a PHSP from a (trained) GAN @@ -32,27 +35,32 @@ def gaga_generate(pth_filename, n, output, output_folder, toggle, n = int(float(n)) # load pth - params, G, D, optim, dtypef = gaga.load(pth_filename, 'auto', verbose=False, epoch=epoch) - f_keys = list(params['keys_list']) + params, G, D, optim, dtypef = gaga.load( + pth_filename, "auto", verbose=False, epoch=epoch + ) + f_keys = list(params["keys_list"]) # cond ? cond_data = None if cond_phsp is not None: - cond_keys = params['cond_keys'] - print(f'Conditional keys {cond_keys}') + cond_keys = params["cond_keys"] + print(f"Conditional keys {cond_keys}") cond_data, cond_read_keys, m = phsp.load(cond_phsp, nmax=n) cond_keys = phsp.str_keys_to_array_keys(cond_keys) cond_data = phsp.select_keys(cond_data, cond_read_keys, cond_keys) - print(f'Conditional keys {cond_keys} {cond_data.shape}') + print(f"Conditional keys {cond_keys} {cond_data.shape}") # generate samples (b is batch size) b = 1e5 start = time.time() - fake = gaga.generate_samples2(params, G, D, n, b, False, True, cond=cond_data) + if cond_phsp is not None: + fake = gaga.generate_samples3(params, G, n, cond=cond_data) + else: + fake = gaga.generate_samples_non_cond(params, G, n, b, False, True) end = time.time() elapsed = end - start pps = n / elapsed - print(f'Timing: {end - start:0.1f} s PPS = {pps:0.0f}') + print(f"Timing: {end - start:0.1f} s PPS = {pps:0.0f}") # Keep X,Y or convert to toggle if toggle: @@ -68,20 +76,20 @@ def gaga_generate(pth_filename, n, output, output_folder, toggle, for k in cond_keys: keys.remove(k) - # write - if output == 'AUTO': - gp = params['penalty'] - gpw = params['penalty_weight'] + # write + if output == "AUTO": + gp = params["penalty"] + gpw = params["penalty_weight"] full_path = os.path.split(pth_filename) b, extension = os.path.splitext(full_path[1]) if not output_folder: - output_folder = '.' - output = f'{b}_{gp}_{gpw}_{init_n}.npy' + output_folder = "." + output = f"{b}_{gp}_{gpw}_{init_n}.npy" output = os.path.join(output_folder, output) print(output) phsp.save_npy(output, fake, keys) # -------------------------------------------------------------------------- -if __name__ == '__main__': +if __name__ == "__main__": gaga_generate() diff --git a/bin/gaga_info b/bin/gaga_info index ddb9963..77fae46 100755 --- a/bin/gaga_info +++ b/bin/gaga_info @@ -7,16 +7,23 @@ import torch from matplotlib import pyplot as plt import gaga_phsp as gaga -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) @click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('pth_filename', nargs=-1) -@click.option('--plot/--no-plot', default=False) -@click.option('--add_energy', default=float(-1), - help='Add the key Ekine with the given value in the parameters of the pth file') -@click.option('--sfig/--no-sfig', is_flag=True, default=False, help='special plot for figure') -@click.option('--short', is_flag=True, default=False, help='one line summary of main parameters') +@click.argument("pth_filename", nargs=-1) +@click.option("--plot/--no-plot", default=False) +@click.option( + "--add_energy", + default=float(-1), + help="Add the key Ekine with the given value in the parameters of the pth file", +) +@click.option( + "--sfig/--no-sfig", is_flag=True, default=False, help="special plot for figure" +) +@click.option( + "--short", is_flag=True, default=False, help="one line summary of main parameters" +) def gaga_info(pth_filename, plot, add_energy, sfig, short): """ \b @@ -31,7 +38,7 @@ def gaga_info(pth_filename, plot, add_energy, sfig, short): fig, ax = plt.subplots(1, 1, figsize=(16, 5)) # or 1,3 for f in pth_filename: - params, G, D, optim, dtypef = gaga.load(f, fatal_on_unknown_keys=False) + params, G, D, optim = gaga.load(f, fatal_on_unknown_keys=False) if short: gaga.print_info_short(params, optim) @@ -50,24 +57,24 @@ def gaga_info(pth_filename, plot, add_energy, sfig, short): if plot: plt.tight_layout() - plt.savefig('a.pdf', dpi=fig.dpi) + plt.savefig("a.pdf", dpi=fig.dpi) plt.show() if add_energy != -1: if len(pth_filename) != 1: - print('Cannot add_energy to several pth_filename') + print("Cannot add_energy to several pth_filename") exit(0) f = pth_filename[0] - params['Ekine'] = add_energy - if params['current_gpu']: + params["Ekine"] = add_energy + if params["current_gpu"]: nn = torch.load(f) else: nn = torch.load(f, map_location=lambda storage, loc: storage) - nn['params'] = params - copyfile(f, f + '.save') + nn["params"] = params + copyfile(f, f + ".save") torch.save(nn, f) # -------------------------------------------------------------------------- -if __name__ == '__main__': +if __name__ == "__main__": gaga_info() diff --git a/bin/gaga_plot b/bin/gaga_plot index 1876347..55496b6 100755 --- a/bin/gaga_plot +++ b/bin/gaga_plot @@ -7,27 +7,55 @@ import numpy as np from matplotlib import pyplot as plt import gaga_phsp as gaga -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) @click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('phsp_filename') -@click.argument('pth_filename') -@click.option('--n', '-n', default=1e4, help='Number of samples to generate') -@click.option('--nb_bins', '-b', default=int(200), help='Number of bins') -@click.option('--toggle/--no-toggle', default=False, help='convert angle to XY (DEBUG)') -@click.option('--quantile', '-q', default=float(0), help='Restrict histogram to quantile') -@click.option('--radius', default=350, help='When convert angle, need the radius (in mm)') -@click.option('--plot2d', - type=(str, str), - help='Add 2D plots (key1,key2), such as --plot2d X Ekine --plot2d X Y ', multiple=True) -@click.option('--epoch', '-e', default=-1, help='Load the G net at the given epoch (-1 for last stored epoch)') -@click.option('--no-title', is_flag=True, default=False) -@click.option('--output', '-o', type=str, help='Do not plot, only output a pdf with the given name') -@click.option('--cond_phsp', '-c', default=None, help='Conditional phsp') -def gaga_plot(phsp_filename, pth_filename, n, nb_bins, - toggle, radius, quantile, plot2d, epoch, - output, no_title, cond_phsp): +@click.argument("phsp_filename") +@click.argument("pth_filename") +@click.option("--n", "-n", default=1e4, help="Number of samples to generate") +@click.option("--nb_bins", "-b", default=int(200), help="Number of bins") +@click.option("--toggle/--no-toggle", default=False, help="convert angle to XY (DEBUG)") +@click.option( + "--quantile", "-q", default=float(0), help="Restrict histogram to quantile" +) +@click.option( + "--radius", default=350, help="When convert angle, need the radius (in mm)" +) +@click.option( + "--plot2d", + type=(str, str), + help="Add 2D plots (key1,key2), such as --plot2d X Ekine --plot2d X Y ", + multiple=True, +) +@click.option( + "--epoch", + "-e", + default=-1, + help="Load the G net at the given epoch (-1 for last stored epoch)", +) +@click.option("--no-title", is_flag=True, default=False) +@click.option( + "--output", + "-o", + type=str, + help="Do not plot, only output a pdf with the given name", +) +@click.option("--cond_phsp", "-c", default=None, help="Conditional phsp") +def gaga_plot( + phsp_filename, + pth_filename, + n, + nb_bins, + toggle, + radius, + quantile, + plot2d, + epoch, + output, + no_title, + cond_phsp, +): """ \b Plot marginal distributions from a GAN-PHSP @@ -49,22 +77,24 @@ def gaga_plot(phsp_filename, pth_filename, n, nb_bins, # load pth params, G, D, optim = gaga.load(pth_filename, epoch=epoch) - f_keys = params['keys'] + f_keys = params["keys"] if isinstance(f_keys, str): - f_keys = params['keys_list'] + f_keys = params["keys_list"] keys = f_keys.copy() # cond ? cond_data = None if not cond_phsp is None: - cond_keys = params['cond_keys'] + cond_keys = params["cond_keys"] cond_data, cond_read_keys, m = phsp.load(cond_phsp, nmax=n) cond_keys = phsp.str_keys_to_array_keys(cond_keys) cond_data = phsp.select_keys(cond_data, cond_read_keys, cond_keys) - print(f'Conditional keys {cond_keys} {cond_data.shape}') - - # generate samples - fake = gaga.generate_samples2(params, G, D, n, int(1e5), normalize=False, to_numpy=True, cond=cond_data) + print(f"Conditional keys {cond_keys} {cond_data.shape}") + fake = gaga.generate_samples3(params, G, n, cond=cond_data) + else: + fake = gaga.generate_samples_non_cond( + params, G, n, int(1e5), normalize=False, to_numpy=True + ) # add cond dimensions if not cond_phsp is None: @@ -102,31 +132,31 @@ def gaga_plot(phsp_filename, pth_filename, n, nb_bins, q1 = quantile q2 = 1.0 - quantile q[k] = (np.quantile(d, q1), np.quantile(d, q2)) - lab = '' + lab = "" if no_title: - lab = 'PHSP ' - gaga.fig_plot_marginal(real, k, keys, ax, i, nb_bins, 'g', q[k], lab) + lab = "PHSP " + gaga.fig_plot_marginal(real, k, keys, ax, i, nb_bins, "g", q[k], lab) i = i + 1 # plot all keys for fake data (same range) i = 0 for k in keys: if no_title: - lab = 'GAN ' - gaga.fig_plot_marginal(fake, k, keys, ax, i, nb_bins, 'r', q[k], lab) + lab = "GAN " + gaga.fig_plot_marginal(fake, k, keys, ax, i, nb_bins, "r", q[k], lab) i = i + 1 # plot 2D distribution if len(keys) > 1: starti = i for kk in keys_2d: - gaga.fig_plot_marginal_2d(real, kk[0], kk[1], keys, ax, i, nb_bins, 'g') + gaga.fig_plot_marginal_2d(real, kk[0], kk[1], keys, ax, i, nb_bins, "g") i = i + 1 # plot 2D distribution i = starti for kk in keys_2d: - gaga.fig_plot_marginal_2d(fake, kk[0], kk[1], keys, ax, i, nb_bins, 'r') + gaga.fig_plot_marginal_2d(fake, kk[0], kk[1], keys, ax, i, nb_bins, "r") i = i + 1 if False: @@ -151,5 +181,5 @@ def gaga_plot(phsp_filename, pth_filename, n, nb_bins, # -------------------------------------------------------------------------- -if __name__ == '__main__': +if __name__ == "__main__": gaga_plot() diff --git a/bin/gaga_train b/bin/gaga_train index b492401..1d65779 100755 --- a/bin/gaga_train +++ b/bin/gaga_train @@ -11,27 +11,57 @@ import gaga_phsp as gaga from box import Box import socket -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) @click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('phsp_filename', type=click.Path(exists=True, file_okay=True, dir_okay=False)) -@click.argument('json_filename', type=click.Path(exists=True, file_okay=True, dir_okay=False)) -@click.option('--output', '-o', help='Output filename, default = automatic name', default='auto') -@click.option('--output_folder', '-f', help='Output folder (ignored if output is not "auto")', default='.') -@click.option('--progress-bar/--no-progress-bar', default=True) -@click.option('--user_param_str', '-ps', - help='overwrite str parameter of the json file', - multiple=True, type=(str, str)) -@click.option('--user_param', '-p', - help='overwrite numeric parameter of the json file', - multiple=True, type=(str, float)) -@click.option('--user_param_int', '-pi', - help='overwrite numeric int parameter of the json file', - multiple=True, type=(str, int)) -def gaga_train(phsp_filename, json_filename, - output, output_folder, - progress_bar, user_param_str, user_param, user_param_int): +@click.argument( + "phsp_filename", type=click.Path(exists=True, file_okay=True, dir_okay=False) +) +@click.argument( + "json_filename", type=click.Path(exists=True, file_okay=True, dir_okay=False) +) +@click.option( + "--output", "-o", help="Output filename, default = automatic name", default="auto" +) +@click.option( + "--output_folder", + "-f", + help='Output folder (ignored if output is not "auto")', + default=".", +) +@click.option("--progress-bar/--no-progress-bar", default=True) +@click.option( + "--user_param_str", + "-ps", + help="overwrite str parameter of the json file", + multiple=True, + type=(str, str), +) +@click.option( + "--user_param", + "-p", + help="overwrite numeric parameter of the json file", + multiple=True, + type=(str, float), +) +@click.option( + "--user_param_int", + "-pi", + help="overwrite numeric int parameter of the json file", + multiple=True, + type=(str, int), +) +def gaga_train( + phsp_filename, + json_filename, + output, + output_folder, + progress_bar, + user_param_str, + user_param, + user_param_int, +): """ \b Train GAN to learn a PHSP (Phase Space File) @@ -78,16 +108,17 @@ def gaga_train(phsp_filename, json_filename, # print parameters for e in params: - if (e[0] != '#'): - print(f' {e:32s} {params[e]}') + if e[0] != "#": + print(f" {e:32s} {params[e]}") # build the model - print(Fore.CYAN + 'Building the GAN model ...' + Style.RESET_ALL) + print(Fore.CYAN + "Building the GAN model ..." + Style.RESET_ALL) gan = gaga.Gan(params) # train - print(Fore.CYAN + 'Start training ...' + Style.RESET_ALL) - model = gan.train(x) + print(Fore.CYAN + "Start training ..." + Style.RESET_ALL) + # model = gan.train(x) + model = gan.train2(x) # stop timer stop = datetime.datetime.now() @@ -99,9 +130,9 @@ def gaga_train(phsp_filename, json_filename, # save output (params is in the model) gan.save(model, params.output_filename) - print(Fore.CYAN + f'Training done: {params.output_filename}') + print(Fore.CYAN + f"Training done: {params.output_filename}") # -------------------------------------------------------------------------- -if __name__ == '__main__': +if __name__ == "__main__": gaga_train() diff --git a/bin/gaga_wasserstein b/bin/gaga_wasserstein deleted file mode 100755 index 67b33ae..0000000 --- a/bin/gaga_wasserstein +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import click -import gaga_phsp as gaga -import gatetools.phsp as phsp -from torch.autograd import Variable -import torch -import numpy as np - -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) -@click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('phsp_filename') -@click.argument('pth_filename') -@click.option('--n', '-n', default=1e4, help='Number of samples to generate') -@click.option('--l', '-l', default=1e2, help='Number of projections') -@click.option('--p', '-p', default=1, help='Wasserstein distance power p=1 default') -@click.option('--keys', '-k', help='Plot the given keys (as a str list such that "X Y Z")', default='') -@click.option('--toggle/--no-toggle', default=False, help='Convert XY to angle') -@click.option('--radius', default=350, help='When convert angle, need the radius (in mm)') -@click.option('--normalize/--no-normalize', default=True, - help='Normalize dimensions') -def gaga_wasserstein(phsp_filename, pth_filename, n, l, p, keys, toggle, radius, normalize): - ''' - \b - Compute sliced Wasserstein between real and GAN generated distributions - - \b - : input phase space file PHSP file (.npy) - : input GAN PTH file (.pth) - ''' - - n = int(n) - - # load phsp - real, r_keys, m = phsp.load(phsp_filename, n) - - # user keys - keys = phsp.str_keys_to_array_keys(keys) - - # load pth - params, G, D, optim, dtypef = gaga.load(pth_filename) - f_keys = params['keys'] - if len(keys) == 0: - keys = f_keys - - # generate samples (do NOT normalize yet) - fake = gaga.generate_samples2(params, G, D, n, 1e5, False, True) - - # select keys - real = phsp.select_keys(real, r_keys, keys) - fake = phsp.select_keys(fake, f_keys, keys) - - # normalize for computing the wasserstein - if normalize: - x_mean = params['x_mean'] - x_std = params['x_std'] - f_keys = params['keys'] - x_mean, z = phsp.add_missing_angle(x_mean, f_keys, keys, radius) - x_std, z = phsp.add_missing_angle(x_std, f_keys, keys, radius) - x_mean = phsp.select_keys(x_mean, z, keys) - x_std = phsp.select_keys(x_std, z, keys) - real = (real-x_mean)/x_std - fake = (fake-x_mean)/x_std - - # convert to pytorch - fake = Variable(torch.from_numpy(fake)).type(dtypef) - real = Variable(torch.from_numpy(real)).type(dtypef) - - # distance - d = gaga.sliced_wasserstein(real, fake, l, p) - print(pth_filename, d, keys, n, l) - - -# -------------------------------------------------------------------------- -if __name__ == '__main__': - gaga_wasserstein() - diff --git a/gaga_phsp/__init__.py b/gaga_phsp/__init__.py index 2e85b96..a449b23 100644 --- a/gaga_phsp/__init__.py +++ b/gaga_phsp/__init__.py @@ -3,6 +3,7 @@ from .gaga_helpers_plot import * from .gaga_helpers_pet import * from .gaga_helpers_spect import * +from .gaga_helpers_tests import * from .gaga_functions import * from .gaga import * from .gaga_model import * diff --git a/gaga_phsp/gaga.py b/gaga_phsp/gaga.py index d7ac7a7..878cd27 100644 --- a/gaga_phsp/gaga.py +++ b/gaga_phsp/gaga.py @@ -1,9 +1,6 @@ import copy from torch.utils.data import DataLoader -import torch -from torch import Tensor from tqdm import tqdm -from .gaga_functions import * from .gaga_helpers import * import gaga_phsp from garf.helpers import get_gpu_device @@ -71,6 +68,7 @@ def init_model(self): self.D = start_D self.G = start_G try: + print(f'Loading last epoch: {start_optim["last_epoch"]}') self.params["start_epoch"] = start_optim["last_epoch"] except: self.params["start_epoch"] = start_optim["current_epoch"][-1] @@ -279,14 +277,14 @@ def init_optim_data(self): return optim def set_net_to_device(self, device): - print("Set model to ", device) + print("Set model to", device) self.G.to(device) self.D.to(device) # print('Set data to GPU') # real_labels and fake_labels are set to cuda before - print("Set optim to ", device) + print("Set optim to", device) self.criterion_dr.to(device) self.criterion_df.to(device) self.criterion_g.to(device) @@ -329,6 +327,9 @@ def train(self, x): self.x = x self.batch_size = self.params.batch_size + # why ? FIXME + self.x = x.astype(np.float32) + # init cuda/mps self.set_net_to_device(self.current_gpu_device) @@ -349,7 +350,8 @@ def train(self, x): loader = DataLoader( self.x, batch_size=batch_size, - num_workers=2, # no gain if larger than 2 (?) + # num_workers=2, # no gain if larger than 2 (?) + num_workers=1, # no gain if larger than 2 (?) # https://discuss.pytorch.org/t/data-loader-multiprocessing-slow-on-macos/131204/3 persistent_workers=True, pin_memory=True, @@ -545,6 +547,126 @@ def train(self, x): self.params.duration = str(stop - start) return optim + def train2(self, x): + """ + Train the GAN + """ + if "dataloader_num_workers" not in self.params: + self.params["dataloader_num_workers"] = 4 + if self.current_gpu_mode == "mps": + self.params["dataloader_num_workers"] = 1 + + # normalisation + print("Normalization") + x, x_mean, x_std = gaga_phsp.normalize_data(x) + self.params["x_mean"] = x_mean + self.params["x_std"] = x_std + + # main dataset + self.x = x + self.batch_size = self.params.batch_size + + # why ? FIXME + self.x = x.astype(np.float32) + self.total_n = len(self.x) + print(f"Total training dataset size = {self.total_n}") + + # init cuda/mps + self.set_net_to_device(self.current_gpu_device) + + # initialise the data structure that will store info during training + optim = self.init_optim_data() + self.optim = optim + + # init conditional + self.condn = len(self.params["cond_keys"]) + self.conditional = self.condn > 0 + if self.conditional: + print(f'Conditional : {self.params["cond_keys"]} ' + str(self.condn)) + + # Sampler + print(f"Dataloader num_workers={self.params['dataloader_num_workers']}") + batch_size = self.params["batch_size"] + loader = DataLoader( + self.x, + batch_size=batch_size, + num_workers=self.params["dataloader_num_workers"], + # https://discuss.pytorch.org/t/data-loader-multiprocessing-slow-on-macos/131204/3 + persistent_workers=True, + pin_memory=True, + # https://discuss.pytorch.org/t/what-is-the-disadvantage-of-using-pin-memory/1702/4 + shuffle=self.params["shuffle"], + # shuffle=False, # if false ~20% faster, seems identical + drop_last=True, # always keep batch_size elements + ) + data_iter = iter(loader) + + # Start training + self.params["end_epoch"] = self.params["start_epoch"] + self.params["epoch"] + print( + f"Epoch from {self.params['start_epoch']} to {self.params['end_epoch']} (total = {self.params['epoch']})" + ) + start = datetime.datetime.now() + pbar = tqdm( + total=self.params["epoch"] * self.total_n, + disable=not self.params["progress_bar"], + ) + + epoch = self.params["start_epoch"] + # for epoch in range(self.params["start_epoch"], self.params["end_epoch"]): + condx = None + while epoch < self.params["end_epoch"]: + # D + for p in self.D.parameters(): + p.requires_grad = True + cont_epoch, condx = self.update_D(data_iter, loader) + if not cont_epoch: + epoch += 1 + pbar.set_postfix( + epoch=epoch, + d_loss=self.d_loss.data.item(), + g_loss=self.g_loss.data.item(), + ) + # sometimes: print and store the full model + self.epoch_dump(epoch) + self.epoch_store(epoch) + # store loss every epoch + self.optim["d_loss_real"].append(self.d_real_loss.data.item()) + self.optim["d_loss_fake"].append(self.d_fake_loss.data.item()) + self.optim["d_loss"].append(self.d_loss.data.item()) + self.optim["g_loss"].append(self.g_loss.data.item()) + # scheduler + if self.is_scheduler_enabled: + self.d_scheduler.step() + self.g_scheduler.step() + + # G + for p in self.D.parameters(): + p.requires_grad = False + self.update_G(condx) + + # housekeeping (to not accumulate gradient) + # zero_grad clears old gradients from the last step + # (otherwise you’d just accumulate the gradients from all loss.backward() calls). + # https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch + self.D.zero_grad() + self.G.zero_grad() + + # update loop + n = self.batch_size * self.params["d_nb_update"] + pbar.update(n) + + # end of training + pbar.close() + stop = datetime.datetime.now() + optim["last_epoch"] = self.params["end_epoch"] + print("Start time = ", start.strftime(gaga_phsp.date_format)) + print("End time = ", stop.strftime(gaga_phsp.date_format)) + print("Duration time = ", (stop - start)) + self.params.Duration = str(stop - start) + self.params.duration = str(stop - start) + return optim + def save(self, optim, filename): """ Save the model @@ -565,14 +687,13 @@ def epoch_dump(self, epoch): try: n = self.params["epoch_dump"] except: - n = 500 + n = 100 if epoch % n != 0: return tqdm.write( - "Epoch %d d_loss: %.5f g_loss: %.5f d_real_loss: %.5f d_fake_loss: %.5f" + f"Epoch {epoch} d_loss: %.5f g_loss: %.5f d_real_loss: %.5f d_fake_loss: %.5f" % ( - epoch, self.d_loss.data.item(), self.g_loss.data.item(), self.d_real_loss.data.item(), @@ -592,7 +713,114 @@ def epoch_store(self, epoch): if epoch % n != 0: return - state = copy.deepcopy(self.G.state_dict()) self.optim["g_model_state"].append(state) self.optim["current_epoch"].append(epoch) + + def update_D(self, data_iter, loader): + cont_epoch = True + batch_size = self.params.batch_size + z_dim = self.params["z_dim"] + nx = self.params["x_dim"] + condx = None + + for i in range(self.params["d_nb_update"]): + # grad + self.D.zero_grad() + + # load input data (and determine if the data pool is empty) + x, ce = self.get_next_input_data(data_iter, loader) + cont_epoch = cont_epoch and ce + + # get decision from the discriminator + d_real_decision = self.D(x) + + # generate z noise (latent) + z = Tensor(self.z_rand(batch_size, z_dim)).to(self.current_gpu_device) + + # concat conditional vector (if any) + if self.conditional: + condx = x[:, nx - self.condn : nx] + # z = torch.cat((z.float(), condx.float()), dim=1) + z = torch.cat((z, condx), dim=1) + + # generate fake data + # (detach to avoid training G on these labels) + d_fake_data = self.G(z).detach() # FIXME detach ? + + # concat conditional vector (if any) + if self.conditional: + # d_fake_data = torch.cat((d_fake_data.float(), condx.float()), dim=1) + d_fake_data = torch.cat((d_fake_data, condx), dim=1) + + # get the fake decision on the fake data + d_fake_decision = self.D(d_fake_data) + + # set penalty + penalty = self.penalty_weight * self.penalty_fct(self, x, d_fake_data) + + # compute loss between decision on real and vector of ones (real_labels) + self.d_real_loss = self.criterion_dr(d_real_decision, self.real_labels) + + # compute loss between decision on fake and vector of zeros (fake_labels) + self.d_fake_loss = self.criterion_df(d_fake_decision, self.fake_labels) + + # backward + self.d_real_loss.backward() + self.d_fake_loss.backward() + if self.penalty_fct != gaga_phsp.zero_penalty: + penalty.backward() + + # sum of loss + self.d_loss = self.d_real_loss + self.d_fake_loss + penalty + + # optimizer + self.d_optimizer.step() + + return cont_epoch, condx + + def update_G(self, condx): + batch_size = self.params.batch_size + z_dim = self.params["z_dim"] + + for i in range(self.params["g_nb_update"]): + # required + self.G.zero_grad() + + # generate z noise (latent) + z = Tensor(self.z_rand(batch_size, z_dim)).to(self.current_gpu_device) + + # conditional + if self.conditional: + # z = torch.cat((z.float(), condx.float()), dim=1) + z = torch.cat((z, condx), dim=1) + + # generate the fake data + g_fake_data = self.G(z) + + # concat conditional vector (if any) + if self.conditional: + # g_fake_data = torch.cat((g_fake_data.float(), condx.float()), dim=1) + g_fake_data = torch.cat((g_fake_data, condx), dim=1) + + # get the fake decision + g_fake_decision = self.D(g_fake_data) + + self.g_loss = self.criterion_g(g_fake_decision, self.real_labels) + + # Backprop + Optimize + self.g_loss.backward(retain_graph=True) + self.g_optimizer.step() + + def get_next_input_data(self, data_iter, loader): + cont_epoch = True + # the input data + # https://github.com/pytorch/pytorch/issues/1917 + try: + data = next(data_iter) + except StopIteration: + data_iter = iter(loader) + data = next(data_iter) + cont_epoch = False + x = Tensor(data).to(self.current_gpu_device) + return x, cont_epoch diff --git a/gaga_phsp/gaga_functions.py b/gaga_phsp/gaga_functions.py index 672e108..df51002 100644 --- a/gaga_phsp/gaga_functions.py +++ b/gaga_phsp/gaga_functions.py @@ -1,7 +1,5 @@ import torch -import torch.nn as nn from torch.autograd import grad as torch_grad -from torch import Tensor import numpy as np """ diff --git a/gaga_phsp/gaga_helpers.py b/gaga_phsp/gaga_helpers.py index d0efa99..e41e830 100644 --- a/gaga_phsp/gaga_helpers.py +++ b/gaga_phsp/gaga_helpers.py @@ -1,19 +1,16 @@ -import numpy as np -import torch -from torch import Tensor import gaga_phsp as gaga -import datetime -import time -import garf from garf.helpers import get_gpu_device import gatetools.phsp as phsp +import numpy as np +import torch +from torch import Tensor from scipy.stats import entropy from scipy.spatial.transform import Rotation -import itk import logging import sys import os from box import Box, BoxList +import datetime logger = logging.getLogger(__name__) @@ -179,10 +176,15 @@ def check_input_params(params, fatal_on_unknown_keys=True): "RMSProp_d_alpha", "RMSProp_g_alpha", "GAN_model", + "d_weight_decay", + "g_weight_decay", "RMSProp_d_weight_decay", "RMSProp_g_weight_decay", "RMSProp_d_centered", "RMSProp_g_centered", + "dataloader_num_workers", + "beta_1", + "beta_2", ] for p in params: if p[0] == "#": @@ -332,9 +334,7 @@ def auto_output_filename(params, output, output_folder): params.output_filename = os.path.join(output_folder, output) -def load( - filename, gpu_mode="auto", verbose=False, epoch=-1, fatal_on_unknown_keys=True -): +def load(filename, gpu_mode="auto", epoch=-1, fatal_on_unknown_keys=True): """ Load a GAN-PHSP Output params = dict with all parameters @@ -483,16 +483,14 @@ def get_z_rand(params): return torch.randn -def generate_samples2( +def generate_samples_non_cond( params, G, - D, n, batch_size=-1, normalize=False, to_numpy=False, z=None, - cond=None, silence=False, ): # batch size -> if n is lower, batch size is n @@ -505,38 +503,18 @@ def generate_samples2( batch_size = int(n) # get z random (gauss or uniform) - z_rand = get_z_rand(params) - - # is this a conditional GAN ? - is_conditional = not cond is None + if z is None: + z_rand = get_z_rand(params) + else: + z_rand = z # normalize the input condition ncond = 0 - if is_conditional: - # normalize the conditional vector - xmean = params["x_mean"][0] - xstd = params["x_std"][0] - xn = params["x_dim"] - cn = len(params["cond_keys"]) - ncond = cn - # mean and std for cond only - xmeanc = xmean[xn - cn : xn] - xstdc = xstd[xn - cn : xn] - # mean and std for non cond - xmeannc = xmean[0 : xn - cn] - xstdnc = xstd[0 : xn - cn] - # normalize the condition - cond = (cond - xmeanc) / xstdc - else: - if len(params["cond_keys"]) > 0: - print( - f'Error : GAN is conditional, you should provide the condition: {params["cond_keys"]}' - ) - exit(0) - - langevin_latent_sampling_flag = False - if "langevin_latent_sampling" in params: - langevin_latent_sampling_flag = True + if len(params["cond_keys"]) > 0: + print( + f'Error : GAN is conditional, you should provide the condition: {params["cond_keys"]}' + ) + exit(0) m = 0 z_dim = params["z_dim"] @@ -544,7 +522,7 @@ def generate_samples2( device = params["current_gpu_device"] current_gpu_mode = params["current_gpu_mode"] rfake_dtype = np.float64 - if current_gpu_mode == 'mps': + if current_gpu_mode == "mps": rfake_dtype = np.float32 rfake = np.empty((0, x_dim - ncond), dtype=rfake_dtype) while m < n: @@ -560,25 +538,6 @@ def generate_samples2( # if None == z: z = Tensor(z_rand(current_gpu_batch_size, z_dim)).to(device) - # condition ? - if is_conditional: - if current_gpu_mode == "mps": - # print("With device mps (gpu), convert data to float32") - acond = cond[m : m + current_gpu_batch_size].astype(np.float32) - else: - acond = cond[m : m + current_gpu_batch_size] - condx = ( - # Tensor(torch.from_numpy(cond[m : m + current_gpu_batch_size])) - Tensor(torch.from_numpy(acond)) - .to(device) - .view(current_gpu_batch_size, cn) - ) - z = torch.cat((z.float(), condx.float()), dim=1) - - # FIXME test langevin - if langevin_latent_sampling_flag: - z = gaga.langevin_latent_sampling(G, D, params, z) - fake = G(z) # put back to cpu to allow concatenation fake = fake.cpu().data.numpy() @@ -589,10 +548,6 @@ def generate_samples2( if not normalize: x_mean = params["x_mean"] x_std = params["x_std"] - if is_conditional: - # do not consider the mean/std of the condition part - x_mean = xmeannc - x_std = xstdnc rfake = (rfake * x_std) + x_mean if to_numpy: @@ -601,55 +556,60 @@ def generate_samples2( return Tensor(torch.from_numpy(rfake)).to(device) -def generate_samples3(params, G, n, cond): +def generate_samples3(params, G, n, cond, to_numpy=True): """ - Like generate_samples2 but with less options, to see if it can be faster - - - batch size is managed elsewhere + Like generate_samples2 but with fewer options + FIXME : consider normalization in G ? """ + # ensure n is int + n = int(n) + # normalize the conditional vector xmean = params["x_mean"][0] xstd = params["x_std"][0] xn = params["x_dim"] cn = len(params["cond_keys"]) - ncond = cn + # mean and std for cond only xmeanc = xmean[xn - cn : xn] xstdc = xstd[xn - cn : xn] + # mean and std for non cond xmeannc = xmean[0 : xn - cn] xstdnc = xstd[0 : xn - cn] + # normalize the condition cond = (cond - xmeanc) / xstdc - m = 0 - z_dim = params["z_dim"] - x_dim = params["x_dim"] - rfake = np.empty((0, x_dim - ncond)) + # special case for gpu = mps (apple) + current_gpu_mode = params["current_gpu_mode"] + if current_gpu_mode == "mps": + cond = cond.astype(np.float32) - # (checking Z allow to reuse z for some special test case) - # if None == z: + # create the z input latent space device = params["current_gpu_device"] + z_dim = params["z_dim"] z = Tensor(torch.randn(n, z_dim)).to(device) - # condition ? - condx = Tensor(torch.from_numpy(cond[m : m + n])).to(device).view(n, cn) + # set condition to the device and concat to the z + condx = Tensor(torch.from_numpy(cond)).to(device) + condx = condx.view(n, cn) z = torch.cat((z.float(), condx.float()), dim=1) - # Go !!! + # Go ! fake = G(z) # put back to cpu to allow concatenation - fake = fake.cpu().data.numpy() - rfake = np.concatenate((rfake, fake), axis=0) + fake = fake.cpu().data.numpy() # FIXME # do not consider the mean/std of the condition part - x_mean = xmeannc - x_std = xstdnc - rfake = (rfake * x_std) + x_mean + fake = (fake * xstdnc) + xmeannc - return rfake + if to_numpy: + return fake + + return Tensor(torch.from_numpy(fake)).to(device) def Jensen_Shannon_divergence(x, y, bins, margin=0): @@ -714,65 +674,34 @@ def wasserstein1D(x, y, p=1): return torch.sum(torch.pow(torch.abs(z), p)) / len(z) -def init_plane(n, angle, radius): +def init_plane3(n, angle, radius, spect_table_shift_mm): """ plane_U, plane_V, plane_point, plane_normal """ - n = int(n) - logger.info(f"Initialisation of plane with radius {radius} ") - plane_U = np.array([1, 0, 0]) - plane_V = np.array([0, 1, 0]) - r = Rotation.from_euler("y", angle, degrees=True) - plane_U = r.apply(plane_U) - plane_V = r.apply(plane_V) + # , spect_table_shift_mm should be 2D ? - # normal vector is the cross product of two direction vectors on the plane - plane_normal = np.cross(plane_U, plane_V) - plane_normal = np.array([plane_normal] * n) - - center = np.array([0, 0, -radius]) - center = np.array([0, 0, -radius]) - center = r.apply(center) - plane_center = np.array( - [ - center, - ] - * n - ) - - plane = { - "plane_U": plane_U, - "plane_V": plane_V, - "rotation": r, - "plane_normal": plane_normal, - "plane_center": plane_center, - } - # logger.info(f'Initialisation of plane {plane} ') - return plane - - -def init_plane2(n, angle, radius, spect_table_shift_mm): - """ - plane_U, plane_V, plane_point, plane_normal - """ - - n = int(n) plane_U = np.array([1, 0, 0]) plane_V = np.array([0, 1, 0]) - r1 = Rotation.from_euler("x", 90, degrees=True) - r2 = Rotation.from_euler("z", angle, degrees=True) - r = r2 * r1 + # r1 = Rotation.from_euler("z", 180, degrees=True) + # r2 = Rotation.from_euler("z", angle, degrees=True) + # r1 = Rotation.from_euler("xz", (180, 0), degrees=True) # <--- this is the correct one + r1 = Rotation.from_euler("yz", (90, -90), degrees=True) + r2 = Rotation.from_euler("yx", (90, 90), degrees=True) + r_a = Rotation.from_euler("z", angle, degrees=True) + # r = r2 * r1 + r = r_a * r1 * r2 plane_U = r.apply(plane_U) plane_V = r.apply(plane_V) # normal vector is the cross product of two direction vectors on the plane plane_normal = np.cross(plane_U, plane_V) - plane_normal = np.array([plane_normal] * n) + plane_normal = np.array([plane_normal] * int(n)) + # axial is Z axis center = np.array([0, -spect_table_shift_mm, -radius]) center = r.apply(center) - plane_center = np.array([center] * n) + plane_center = np.array([center] * int(n)) plane = { "plane_U": plane_U, @@ -785,114 +714,7 @@ def init_plane2(n, angle, radius, spect_table_shift_mm): return plane -def project_on_plane(x, plane, image_plane_size_mm, debug=False): - """ - Project the x points (Ekine X Y Z dX dY dZ) - on the image plane defined by plane_U, plane_V, plane_center, plane_normal - """ - - logger.info(f"Projection of {len(x)} particles on the plane") - logger.info(f"Plane size is {image_plane_size_mm} mm") - - # shorter variable names - - # n is the normal plane, duplicated n times - n = plane["plane_normal"][0 : len(x)] - - # c0 is the center of the plane, duplicated n times - c0 = plane["plane_center"][0 : len(x)] - - # r is the rotation matrix of the plane, according to the current rotation angle (around Y) - r = plane["rotation"][0 : len(x)] - - # p is the set of points position generated by the GAN - p = x[:, 1:4] - - # u is the set of points direction generated by the GAN - u = x[:, 4:7] - - # w is the set of vectors from all points to the plane center - w = p - c0 - - # project to plane - ## dot product : out = (x*y).sum(-1) - # https://rosettacode.org/wiki/Find_the_intersection_of_a_line_with_a_plane#Python - # http://geomalgorithms.com/a05-_intersect-1.html - # https://github.com/pytorch/pytorch/issues/18027 - ndotu = (n * u).sum(-1) # dot product between normal plane (n) and direction (u) - si = ( - -(n * w).sum(-1) / ndotu - ) # dot product between normal plane and vector from plane to point (w) - - # only positive (direction to the plane) - mask = si > 0 - mw = w[mask] - mu = u[mask] - mc0 = c0[mask] - mn = n[mask] - mx = x[mask] - mp = p[mask] - msi = si[mask] - mnb = len(msi) - logger.info(f"Remove negative direction, remains {mnb}/{len(x)}") - - # si is a (nb) size vector, expand it to (nb x 3) - msi = np.array([msi] * 3).T - - # intersection between point-direction and plane - psi = mp + msi * mu - - # apply the inverse of the rotation - ri = r.inv() - psip = ri.apply(psi) # - offset - - # remove out of plane (needed ??) - sizex = image_plane_size_mm[0] / 2.0 - sizey = image_plane_size_mm[1] / 2.0 - mask1 = psip[:, 0] < sizex - mask2 = psip[:, 0] > -sizex - mask3 = psip[:, 1] < sizey - mask4 = psip[:, 1] > -sizey - m = mask1 & mask2 & mask3 & mask4 - psip = psip[m] - psi = psi[m] - mp = mp[m] - mu = mu[m] - mx = mx[m] - mc0 = mc0[m] - nb = len(psip) - logger.info(f"Remove points that are out of detector, remains {nb}/{len(x)}") - - # reshape results - pu = psip[:, 0].reshape((nb, 1)) # u - pv = psip[:, 1].reshape((nb, 1)) # v - y = np.concatenate((pu, pv), axis=1) - - # rotate direction according to the plane - mup = ri.apply(mu) - norm = np.linalg.norm(mup, axis=1, keepdims=True) - mup = mup / norm - dx = mup[:, 0] - dy = mup[:, 1] - - # FIXME -> clip arcos -1;1 ? - - # convert direction into theta/phi - # theta is acos(dy) - # phi is acos(dx) - theta = np.degrees(np.arccos(dy)).reshape((nb, 1)) - phi = np.degrees(np.arccos(dx)).reshape((nb, 1)) - y = np.concatenate((y, theta), axis=1) - y = np.concatenate((y, phi), axis=1) - - # concat the E - E = mx[:, 0].reshape((nb, 1)) - data = np.concatenate((y, E), axis=1) - - return data - - -def project_on_plane2(x, plane, image_plane_size_mm): +def project_on_plane(x, plane, image_plane_size_mm): """ Project the x points (Ekine X Y Z dX dY dZ) on the image plane defined by plane_U, plane_V, plane_center, plane_normal @@ -911,7 +733,7 @@ def project_on_plane2(x, plane, image_plane_size_mm): p = x[:, 1:4] # FIXME indices of the position # u is the set of points direction generated by the GAN - u = x[:, 4:7] # FIXME indices of the position + u = x[:, 4:7] # FIXME indices of the direction # w is the set of vectors from all points to the plane center w = p - c0 @@ -993,95 +815,6 @@ def project_on_plane2(x, plane, image_plane_size_mm): return data -def gaga_garf_generate_image(p): - # param - gan_params = p["gan_params"] - G = p["G"] - D = p["D"] - batch_size = p["batch_size"] - gan_batch_size = p["gan_batch_size"] - plane = p["plane"] - image_plane_size_mm = p["image_plane_size_mm"] - debug = p["debug"] - garf_nn = p["garf_nn"] - garf_model = p["garf_model"] - garf_param = p["garf_param"] - pbar = p["pbar"] - n = p["n"] - - ev = 0 - images = [] - sq_images = [] - while ev < n: - # check generation of the exact nb of samples - current_batch_size = batch_size - if current_batch_size > n - ev: - current_batch_size = n - ev - - # Step 1 : GAN - t1 = time.time() - logger.info(f"Generating {current_batch_size} events") - x = gaga.generate_samples2( - gan_params, - G, - D, - current_batch_size, - gan_batch_size, - normalize=False, - to_numpy=True, - ) - # print('batch / x', current_batch_size, len(x)) - logger.info("Computation time: {0:.3f} sec".format(time.time() - t1)) - - # Step 2 : Projection - t1 = time.time() - px = gaga.project_on_plane( - x, plane, image_plane_size_mm=image_plane_size_mm, debug=debug - ) - logger.info("Computation time: {0:.3f} sec".format(time.time() - t1)) - - # Step3 : GARF - # output image expressed in counts/samples (generated samples) - t1 = time.time() - logger.info(f"Building image with {len(px)}/{current_batch_size} particles") - garf_param["N_dataset"] = current_batch_size - img, sq_img = garf.build_arf_image_with_nn( - garf_nn, garf_model, px, garf_param, verbose=False, debug=debug - ) - images.append(img) - sq_images.append(sq_img) - logger.info("Computation time: {0:.3f} sec".format(time.time() - t1)) - - ev += current_batch_size - pbar.update(current_batch_size) - ev = min(ev, n) - logger.info("") - - # mean images - im_iter = iter(images) - im = next(im_iter) - data = itk.GetArrayFromImage(im) - for im in im_iter: - d = itk.GetArrayViewFromImage(im) - data += d - data = data / len(images) - img = itk.GetImageFromArray(data) - img.CopyInformation(images[0]) - - # mean images - im_iter = iter(sq_images) - im = next(im_iter) - data = itk.GetArrayFromImage(im) - for im in im_iter: - d = itk.GetArrayViewFromImage(im) - data += d - data = data / len(sq_images) - sq_img = itk.GetImageFromArray(data) - sq_img.CopyInformation(sq_images[0]) - - return img, sq_img - - def append_gaussian(data, mean, cov, n, vx=None, vy=None): x, y = np.random.multivariate_normal(mean, cov, n).T d = np.column_stack((x, y)) diff --git a/gaga_phsp/gaga_helpers_gate.py b/gaga_phsp/gaga_helpers_gate.py new file mode 100644 index 0000000..102a933 --- /dev/null +++ b/gaga_phsp/gaga_helpers_gate.py @@ -0,0 +1,242 @@ +import numpy as np +import gaga_phsp as gaga +import garf +from garf.helpers import get_gpu_device +from scipy.spatial.transform import Rotation +import itk +import opengate.sources.gansources as gansources +from tqdm import tqdm + + +def voxelized_source_generator(source_filename): + gen = gansources.VoxelizedSourceConditionGenerator( + source_filename, use_activity_origin=True + ) + gen.compute_directions = True + return gen.generate_condition + + +def gaga_garf_generate_spect_initialize(gaga_user_info, garf_user_info): + # ensure int + gaga_user_info.batch_size = int(gaga_user_info.batch_size) + garf_user_info.batch_size = int(garf_user_info.batch_size) + + # load gaga pth + gaga_params, G, D, _ = gaga.load(gaga_user_info.pth_filename) + gaga_user_info.gaga_params = gaga_params + gaga_user_info.G = G + gaga_user_info.D = D + + # load garf pth + nn, model = garf.load_nn(garf_user_info.pth_filename, verbose=False) + garf_user_info.nn = nn + garf_user_info.model_data = model + + # set gpu/cpu for garf + current_gpu_mode, current_gpu_device = get_gpu_device(garf_user_info.gpu_mode) + garf_user_info.nn.model_data["current_gpu_mode"] = current_gpu_mode + garf_user_info.nn.model_data["current_gpu_device"] = current_gpu_device + + # garf image plane rotation + r = Rotation.from_euler("x", 0, degrees=True) + garf_user_info.plane_rotation = r + + # image plane size + garf_user_info.nb_energy_windows = garf_user_info.nn.model_data["n_ene_win"] + size = garf_user_info.image_size + spacing = garf_user_info.image_spacing + garf_user_info.image_plane_size_mm = np.array( + [size[0] * spacing[0], size[1] * spacing[1]] + ) + + # size and spacing must be np + garf_user_info.image_spacing = np.array(garf_user_info.image_spacing) + garf_user_info.image_hspacing = garf_user_info.image_spacing / 2.0 + garf_user_info.image_plane_hsize_mm = garf_user_info.image_plane_size_mm / 2 + + +def do_nothing(a): + pass + + +def gaga_garf_generate_spect( + gaga_user_info, garf_user_info, n, angle_rotations, verbose=True +): + # n must be int + n = int(n) + + # allocate the initial list of images : + # number of angles x number of energy windows x image size + nbe = garf_user_info.nb_energy_windows + size = garf_user_info.image_size + spacing = garf_user_info.image_spacing + data_size = [len(angle_rotations), nbe, size[0], size[1]] + data_img = np.zeros(data_size, dtype=np.float64) + + # verbose + if verbose: + print(f"GAGA pth = {gaga_user_info.pth_filename}") + print(f"GARF pth = {garf_user_info.pth_filename}") + print(f"GARF hist slice = {garf_user_info.hit_slice}") + print(f"Activity source = {gaga_user_info.activity_source}") + print(f"Number of energy windows = {garf_user_info.nb_energy_windows}") + print(f"Image plane size (pixel) = {garf_user_info.image_size}") + print(f"Image plane spacing (mm) = {spacing}") + print(f"Image plane size (mm) = {garf_user_info.image_plane_size_mm}") + print(f"Number of angles = {len(angle_rotations)}") + print(f"GAGA batch size = {gaga_user_info.batch_size:.1e}") + print(f"GARF batch size = {garf_user_info.batch_size:.1e}") + print( + f"GAGA GPU mode = {gaga_user_info.gaga_params['current_gpu_mode']}" + ) + print( + f"GARF GPU mode = {garf_user_info.nn.model_data['current_gpu_mode']}" + ) + + # create the planes for each angle (with max number of values = batch_size) + planes = [] + gaga_user_info.batch_size = int(gaga_user_info.batch_size) + projected_points = [None] * len(angle_rotations) + for rot in angle_rotations: + plane = garf.arf_plane_init(garf_user_info, rot, gaga_user_info.batch_size) + planes.append(plane) + + # initialize the condition generator + f = gaga_user_info.activity_source + cond_generator = gansources.VoxelizedSourceConditionGenerator( + f, use_activity_origin=False # FIXME true or false ? + ) + cond_generator.compute_directions = True + cond_generator.translation = gaga_user_info.cond_translation + + # prepare verbose + verb_gaga_1 = do_nothing + verb_gaga_2 = do_nothing + verb_garf_1 = do_nothing + if gaga_user_info.verbose > 0: + verb_gaga_1 = tqdm.write + if gaga_user_info.verbose > 1: + verb_gaga_2 = tqdm.write + if garf_user_info.verbose > 0: + verb_garf_1 = tqdm.write + + # loop on GAGA batches + current_n = 0 + pbar = tqdm(total=n) + nb_hits_on_plane = [0] * len(angle_rotations) + nb_detected_hits = [0] * len(angle_rotations) + while current_n < n: + # check generation of the exact nb of samples + current_batch_size = gaga_user_info.batch_size + if current_batch_size > n - current_n: + current_batch_size = n - current_n + verb_gaga_1(f"Current event = {current_n}/{n}") + + # generate samples + x = gaga.generate_samples_with_vox_condition( + gaga_user_info, cond_generator, current_batch_size + ) + + # FIXME filter Energy too low (?) + + # generate projections + for i in range(len(angle_rotations)): + # project on plane + plane = planes[i] + px = garf.arf_plane_project(x, plane, garf_user_info.image_plane_size_mm) + verb_gaga_2(f"\tAngle {i}, number of gamma reaching the plane = {len(px)}") + nb_hits_on_plane[i] += len(px) + if len(px) == 0: + continue + + # Store projected points until garf_batch_size is full before build image + cpx = projected_points[i] + if cpx is None: + if len(px) > garf_user_info.batch_size: + print( + f"Cannot use GARF, {len(px)} points while batch size is {garf_user_info.batch_size}" + ) + exit(-1) + projected_points[i] = px + else: + if len(cpx) + len(px) > garf_user_info.batch_size: + # build image + image = data_img[i] + verb_garf_1( + f"\tGARF rotation {i}: update image with {len(cpx)} hits ({current_n}/{n})" + ) + nb_detected_hits[i] += garf.build_arf_image_from_projected_points( + garf_user_info, cpx, image + ) + projected_points[i] = px + else: + projected_points[i] = np.concatenate((cpx, px), axis=0) + + # next angles index + i += 1 + + # iterate + current_n += current_batch_size + pbar.update(current_batch_size) + + # remaining projected points + for i in range(len(angle_rotations)): + cpx = projected_points[i] + if cpx is None or len(cpx) == 0: + continue + if garf_user_info.verbose > 0: + print(f"GARF rotation {i}: update image with {len(cpx)} hits (final)") + image = data_img[i] + nb_detected_hits[i] = garf.build_arf_image_from_projected_points( + garf_user_info, cpx, image + ) + + if verbose: + for i in range(len(angle_rotations)): + print(f"Angle {i}, nb of hits on plane = {nb_hits_on_plane[i]}") + print(f"Angle {i}, nb of detected hits = {nb_detected_hits[i]}") + + # Remove first slice (nb of hits) + if not garf_user_info.hit_slice: + data_img = data_img[:, 1:, :] + + # Final list of images + images = [] + for i in range(len(angle_rotations)): + img = itk.image_from_array(data_img[i]) + spacing = [spacing[0], spacing[1], 1] + origin = [ + -size[0] * spacing[0] / 2 + spacing[0] / 2, + -size[1] * spacing[1] / 2 + spacing[1] / 2, + 0, + ] + img.SetSpacing(spacing) + img.SetOrigin(origin) + images.append(img) + i += 1 + + return images + + +def generate_samples_with_vox_condition(gaga_user_info, cond_generator, n): + # generate conditions + cond = cond_generator.generate_condition(n) + + # generate samples + x = gaga.generate_samples3( + gaga_user_info.gaga_params, + gaga_user_info.G, + n=n, + cond=cond, + ) + + # move backward + pos_index = 1 # FIXME + dir_index = 4 + position = x[:, pos_index : pos_index + 3] + direction = x[:, dir_index : dir_index + 3] + x[:, pos_index : pos_index + 3] = ( + position - gaga_user_info.backward_distance * direction + ) + + return x diff --git a/gaga_phsp/gaga_helpers_plot.py b/gaga_phsp/gaga_helpers_plot.py index cd0660e..6690431 100644 --- a/gaga_phsp/gaga_helpers_plot.py +++ b/gaga_phsp/gaga_helpers_plot.py @@ -1,16 +1,7 @@ import numpy as np -import torch -import gaga_phsp -import datetime -import time -import garf import gatetools.phsp as phsp from scipy.stats import kde from matplotlib import pyplot as plt -from scipy.stats import entropy -from scipy.spatial.transform import Rotation -import logging -import sys def plot_epoch(ax, params, optim, filename): @@ -19,26 +10,26 @@ def plot_epoch(ax, params, optim, filename): 3 panels : all epoch / first 20% / last 1% """ - x1 = np.asarray(optim['d_loss_real']) - x2 = np.asarray(optim['d_loss_fake']) + x1 = np.asarray(optim["d_loss_real"]) + x2 = np.asarray(optim["d_loss_fake"]) # x = -np.add(x1,x2) - x = -np.asarray(optim['d_loss']) # with grad penalty + x = -np.asarray(optim["d_loss"]) # with grad penalty - epoch = np.arange(params['start_epoch'], params['end_epoch'], 1) + epoch = np.arange(params["start_epoch"], params["end_epoch"], 1) a = ax # [0] l = filename - a.plot(epoch, x, '-', label='D_loss (GP) ' + l) + a.plot(epoch, x, "-", label="D_loss (GP) " + l) z = np.zeros_like(x) - a.set_xlabel('epoch') - a.plot(epoch, z, '--') + a.set_xlabel("epoch") + a.plot(epoch, z, "--") a.legend() # print(params['validation_dataset']) - if not 'validation_dataset' in params or params['validation_dataset'] == None: + if not "validation_dataset" in params or params["validation_dataset"] == None: return - print('validation') - x = -np.asarray(optim['validation_d_loss']) - a.plot(epoch, x, '-', label='Valid') + print("validation") + x = -np.asarray(optim["validation_d_loss"]) + a.plot(epoch, x, "-", label="Valid") a.legend() return @@ -46,28 +37,28 @@ def plot_epoch(ax, params, optim, filename): a = ax[1] n = int(len(x) * 0.2) # first 20% xc = x[0:n] - a.plot(epoch, x, '-', label='D_loss ' + l) + a.plot(epoch, x, "-", label="D_loss " + l) z = np.zeros_like(xc) - a.set_xlabel('epoch') + a.set_xlabel("epoch") a.set_xlim((0, n)) ymin = np.amin(xc) ymax = np.amax(xc) a.set_ylim((ymin, ymax)) - a.plot(z, '--') + a.plot(z, "--") a.legend() a = ax[2] n = max(10, int(len(x) * 0.01)) # last 1% xc = x - a.plot(xc, '.-', label='D_loss ' + l) + a.plot(xc, ".-", label="D_loss " + l) z = np.zeros_like(xc) - a.set_xlabel('epoch') + a.set_xlabel("epoch") a.set_xlim((len(xc) - n, len(xc))) - xc = x[len(x) - n:len(x)] + xc = x[len(x) - n : len(x)] ymin = np.amin(xc) ymax = np.amax(xc) a.set_ylim((ymin, ymax)) - a.plot(epoch, z, '--') + a.plot(epoch, z, "--") a.legend() @@ -76,11 +67,13 @@ def plot_epoch2(ax, params, optim, filename): Plot D loss wrt to epoch """ - dlr = np.asarray(optim['d_loss_real']) - dlf = np.asarray(optim['d_loss_fake']) - dl = np.asarray(optim['d_loss']) # with grad penalty - gl = np.asarray(optim['g_loss']) - epoch = np.arange(params['start_epoch'], params['end_epoch'], 1) + dlr = np.asarray(optim["d_loss_real"]) + dlf = np.asarray(optim["d_loss_fake"]) + dl = np.asarray(optim["d_loss"]) # with grad penalty + gl = np.asarray(optim["g_loss"]) + epoch = np.arange(params["start_epoch"], params["end_epoch"], 1) + + print(f"dl size = {len(dl)}") # one epoch is when all the training dataset is seen step = 1 # int(params['training_size'] / params['batch_size']) @@ -95,20 +88,20 @@ def plot_epoch2(ax, params, optim, filename): l = filename # a.plot(epoch, dlr, '-', label='D_loss_real' + l, alpha=0.5) # a.plot(epoch, dlf, '-', label='D_loss_fake' + l, alpha=0.5) - a.plot(epoch, dl, '-', label='D_loss (GP) ' + l) - a.plot(epoch, gl, '-', label='G_loss ' + l, alpha=0.5) + a.plot(epoch, dl, "-", label="D_loss (GP) " + l) + a.plot(epoch, gl, "-", label="G_loss " + l, alpha=0.5) z = np.zeros_like(dl) - a.set_xlabel('epoch') - a.plot(epoch, z, '--') + a.set_xlabel("epoch") + a.plot(epoch, z, "--") a.legend() # print(params['validation_dataset']) - if not 'validation_dataset' in params or params['validation_dataset'] is None: + if not "validation_dataset" in params or params["validation_dataset"] is None: return - print('Plot with validation dataset') - x = np.asarray(optim['validation_d_loss']) + print("Plot with validation dataset") + x = np.asarray(optim["validation_d_loss"]) x = x[::step] - a.plot(epoch, x, '-', label='Valid ' + l) + a.plot(epoch, x, "-", label="Valid " + l) a.legend() return @@ -116,55 +109,55 @@ def plot_epoch2(ax, params, optim, filename): a = ax[1] n = int(len(x) * 0.2) # first 20% xc = x[0:n] - a.plot(epoch, x, '-', label='D_loss ' + l) + a.plot(epoch, x, "-", label="D_loss " + l) z = np.zeros_like(xc) - a.set_xlabel('epoch') + a.set_xlabel("epoch") a.set_xlim((0, n)) ymin = np.amin(xc) ymax = np.amax(xc) a.set_ylim((ymin, ymax)) - a.plot(z, '--') + a.plot(z, "--") a.legend() a = ax[2] n = max(10, int(len(x) * 0.01)) # last 1% xc = x - a.plot(xc, '.-', label='D_loss ' + l) + a.plot(xc, ".-", label="D_loss " + l) z = np.zeros_like(xc) - a.set_xlabel('epoch') + a.set_xlabel("epoch") a.set_xlim((len(xc) - n, len(xc))) - xc = x[len(x) - n:len(x)] + xc = x[len(x) - n : len(x)] ymin = np.amin(xc) ymax = np.amax(xc) a.set_ylim((ymin, ymax)) - a.plot(epoch, z, '--') + a.plot(epoch, z, "--") a.legend() def plot_epoch_wasserstein(ax, optim, filename): """ - Plot wasserstein + Plot wasserstein """ - y = np.asarray(optim['w_value']) - x = np.asarray(optim['w_epoch']) + y = np.asarray(optim["w_value"]) + x = np.asarray(optim["w_epoch"]) if len(x) < 1: return a = ax[0].twinx() - a.plot(x, y, '-', color='r', label='W') + a.plot(x, y, "-", color="r", label="W") a.legend() a = ax[1].twinx() - a.plot(x, y, '-', color='r', label='W') + a.plot(x, y, "-", color="r", label="W") a.legend() a = ax[2].twinx() - a.plot(x, y, '.-', color='r', label='W') + a.plot(x, y, ".-", color="r", label="W") a.legend() -def fig_plot_marginal(x, k, keys, ax, i, nb_bins, color, r='', lab=''): +def fig_plot_marginal(x, k, keys, ax, i, nb_bins, color, r="", lab=""): a = phsp.fig_get_sub_fig(ax, i) index = keys.index(k) if len(x[0]) > 1: @@ -172,23 +165,29 @@ def fig_plot_marginal(x, k, keys, ax, i, nb_bins, color, r='', lab=''): else: d = x # label = ' {} $\mu$={:.2f} $\sigma$={:.2f}'.format(k, np.mean(d), np.std(d)) - label = f'{lab} {k} $\mu$={np.mean(d):.2f} $\sigma$={np.std(d):.2f}' - if r != '': - a.hist(d, nb_bins, - # density=True, - histtype='stepfilled', - facecolor=color, - alpha=0.5, - range=r, - label=label) + label = f"{lab} {k} $\mu$={np.mean(d):.2f} $\sigma$={np.std(d):.2f}" + if r != "": + a.hist( + d, + nb_bins, + # density=True, + histtype="stepfilled", + facecolor=color, + alpha=0.5, + range=r, + label=label, + ) else: - a.hist(d, nb_bins, - # density=True, - histtype='stepfilled', - facecolor=color, - alpha=0.5, - label=label) - a.set_ylabel('Counts') + a.hist( + d, + nb_bins, + # density=True, + histtype="stepfilled", + facecolor=color, + alpha=0.5, + label=label, + ) + a.set_ylabel("Counts") a.legend() @@ -198,36 +197,30 @@ def fig_plot_marginal_2d(x, k1, k2, keys, ax, i, nbins, color): d1 = x[:, index1] index2 = keys.index(k2) d2 = x[:, index2] - label = '{} {}'.format(k1, k2) + label = "{} {}".format(k1, k2) - ptype = 'scatter' + ptype = "scatter" # ptype = 'hist' - if ptype == 'scatter': - a.scatter(d1, d2, color=color, - alpha=0.25, - edgecolor='none', - s=1) + if ptype == "scatter": + a.scatter(d1, d2, color=color, alpha=0.25, edgecolor="none", s=1) - if ptype == 'hist': + if ptype == "hist": cmap = plt.cm.Greens - if color == 'r': + if color == "r": cmap = plt.cm.Reds - a.hist2d(d1, d2, - bins=(nbins, nbins), - alpha=0.7, - cmap=cmap) + a.hist2d(d1, d2, bins=(nbins, nbins), alpha=0.7, cmap=cmap) - if ptype == 'density': + if ptype == "density": x = d1 y = d2 - print('kde') + print("kde") k = kde.gaussian_kde([x, y]) - xi, yi = np.mgrid[x.min():x.max():nbins * 1j, y.min():y.max():nbins * 1j] + xi, yi = np.mgrid[ + x.min() : x.max() : nbins * 1j, y.min() : y.max() : nbins * 1j + ] zi = k(np.vstack([xi.flatten(), yi.flatten()])) - a.pcolormesh(xi, yi, - zi.reshape(xi.shape), - alpha=0.5) + a.pcolormesh(xi, yi, zi.reshape(xi.shape), alpha=0.5) a.set_xlabel(k1) a.set_ylabel(k2) @@ -242,7 +235,7 @@ def fig_plot_diff_2d(x, y, keys, kk, ax, fig, nb_bins): x2 = x[:, index2] y1 = y[:, index1] y2 = y[:, index2] - label = '{} {}'.format(k1, k2) + label = "{} {}".format(k1, k2) # compute histo H_x, xedges_x, yedges_x = np.histogram2d(x1, x2, bins=nb_bins) @@ -250,7 +243,7 @@ def fig_plot_diff_2d(x, y, keys, kk, ax, fig, nb_bins): # make diff # H = (H_y - H_x)/H_y - np.seterr(divide='ignore', invalid='ignore') + np.seterr(divide="ignore", invalid="ignore") H = np.divide(H_y - H_x, H_y) # plot @@ -276,16 +269,18 @@ def fig_plot_projected(data): f, ax = plt.subplots(2, 2, figsize=(10, 10)) - n, bins, patches = ax[0, 0].hist(theta, b, density=True, facecolor='g', alpha=0.35) - n, bins, patches = ax[0, 1].hist(phi, b, density=True, facecolor='g', alpha=0.35) - n, bins, patches = ax[1, 0].hist(E * 1000, b, density=True, facecolor='b', alpha=0.35) - ax[1, 1].scatter(x, y, color='r', alpha=0.35, s=1) - - ax[0, 0].set_xlabel('Theta angle (deg)') - ax[0, 1].set_xlabel('Phi angle (deg)') - ax[1, 0].set_xlabel('Energy (keV)') - ax[1, 1].set_xlabel('X') - ax[1, 1].set_ylabel('Y') + n, bins, patches = ax[0, 0].hist(theta, b, density=True, facecolor="g", alpha=0.35) + n, bins, patches = ax[0, 1].hist(phi, b, density=True, facecolor="g", alpha=0.35) + n, bins, patches = ax[1, 0].hist( + E * 1000, b, density=True, facecolor="b", alpha=0.35 + ) + ax[1, 1].scatter(x, y, color="r", alpha=0.35, s=1) + + ax[0, 0].set_xlabel("Theta angle (deg)") + ax[0, 1].set_xlabel("Phi angle (deg)") + ax[1, 0].set_xlabel("Energy (keV)") + ax[1, 1].set_xlabel("X") + ax[1, 1].set_ylabel("Y") plt.tight_layout() plt.show() diff --git a/gaga_phsp/gaga_helpers_spect.py b/gaga_phsp/gaga_helpers_spect.py index f7b9056..d4ab129 100755 --- a/gaga_phsp/gaga_helpers_spect.py +++ b/gaga_phsp/gaga_helpers_spect.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import scipy -import torch import numpy as np import gaga_phsp as gaga diff --git a/gaga_phsp/gaga_helpers_tests.py b/gaga_phsp/gaga_helpers_tests.py new file mode 100644 index 0000000..4546394 --- /dev/null +++ b/gaga_phsp/gaga_helpers_tests.py @@ -0,0 +1,65 @@ + +import os +import inspect +import colored +import sys +import scipy +import numpy as np + +try: + color_error = colored.fg("red") + colored.attr("bold") + color_warning = colored.fg("orange_1") + color_ok = colored.fg("green") +except AttributeError: + # new syntax in colored>=1.5 + color_error = colored.fore("red") + colored.style("bold") + color_warning = colored.fore("orange_1") + color_ok = colored.fore("green") + + + +def fatal(s): + caller = inspect.getframeinfo(inspect.stack()[1][0]) + ss = f"(in {caller.filename} line {caller.lineno})" + ss = colored.stylize(ss, color_error) + print(ss) + s = colored.stylize(s, color_error) + print(s) + raise Exception(s) + +def run_and_check(cmd): + print() + print(f'Running : {cmd}') + r = os.system(f"{cmd} ") + if r != 0: + fatal(f"Command error : {cmd}") + + +def test_ok(is_ok=False): + if is_ok: + s = "Great, tests are ok." + s = "\n" + colored.stylize(s, color_ok) + print(s) + # sys.exit(0) + else: + s = "Error during the tests !" + s = "\n" + colored.stylize(s, color_error) + print(s) + sys.exit(-1) + +def compare_sampled_points(keys, real, fake, wtol=0.1, tol=0.08): + for i in range(len(keys)): + w = scipy.stats.wasserstein_distance(real[:,i], fake[:,i]) + print(f"({i}) Key {keys[i]}, wass = {w:.2f} tol = {wtol:.2f}") + if w > wtol: + fatal(f"Difference between real and fake too large {w} vs {wtol}") + real_mean = np.mean(real[:,i]) + real_std = np.std(real[:,i]) + fake_mean = np.mean(fake[:,i]) + fake_std = np.std(fake[:,i]) + d_mean = np.fabs((real_mean - fake_mean) / real_mean) + d_std = np.fabs((real_std - fake_std) / real_std) + print(f"({i}) Mean real vs fake : {real_mean:.2f} {fake_mean:.2f} {d_mean*100:.2f}%") + print(f"({i}) Std real vs fake : {real_std:.2f} {fake_std:.2f} {d_std*100:.2f}%") + if d_mean > tol or d_std > tol: + fatal(f"Difference between real and fake too large {d_mean} {d_std} vs {tol}") \ No newline at end of file diff --git a/gaga_phsp/gaga_model.py b/gaga_phsp/gaga_model.py index 7693731..3338147 100644 --- a/gaga_phsp/gaga_model.py +++ b/gaga_phsp/gaga_model.py @@ -2,7 +2,6 @@ import torch import gaga_phsp as gaga from torch import Tensor -from types import MethodType class MyLeakyReLU(nn.Module): diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f6c1689 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel", +] +build-backend = "setuptools.build_meta" diff --git a/readme.md b/readme.md index 2d3958e..8820088 100644 --- a/readme.md +++ b/readme.md @@ -11,5 +11,4 @@ https://www.ncbi.nlm.nih.gov/pubmed/31470418 A method is proposed and evaluated to model large and inconvenient phase space files used in Monte Carlo simulations by a compact Generative Adversarial Network (GAN). The GAN is trained based on a phase space dataset to create a neural network, called Generator (G), allowing G to mimic the multidimensional data distribution of the phase space. At the end of the training process, G is stored with about 0.5 million weights, around 10MB, instead of few GB of the initial file. Particles are then generated with G to replace the phase space dataset.&#13; &#13; This concept is applied to beam models from linear accelerators (linacs) and from brachytherapy seed models. Simulations using particles from the reference phase space on one hand and those generated by the GAN on the other hand were compared. 3D distributions of deposited energy obtained from source distributions generated by the GAN were close to the reference ones, with less than 1\% of voxel-by-voxel relative difference. Sharp parts such as the brachytherapy emission lines in the energy spectra were not perfectly modeled by the GAN. Detailed statistical properties and limitations of the GAN-generated particles still require further investigation, but the proposed exploratory approach is already promising and paves the way for a wide range of applications -Examples : -https://github.com/OpenGATE/GateContrib/tree/master/dosimetry/gaga-phsp +Tests in opengate (https://github.com/OpenGATE/opengate): see test066. \ No newline at end of file diff --git a/setup.py b/setup.py index 7ad5ac7..b84f4bc 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "colorama", "click", "scipy", - "garf>=2.4", + "garf>=2.5", "matplotlib", "gatetools", # 'torch' # the installation of torch is managed by garf @@ -35,10 +35,9 @@ "bin/gaga_plot", "bin/gaga_generate", "bin/gaga_gauss_test", + "bin/gaga_gauss_cond_test", "bin/gaga_gauss_plot", "bin/gaga_convert_pth_to_pt", - "bin/gaga_wasserstein", - "bin/gaga_garf_generate_img", "bin/gaga_pairs_to_tlor", "bin/gaga_tlor_to_pairs", "bin/gaga_pet_to_pairs_old", diff --git a/tests/json/cg1.json b/tests/json/cg1.json index cb3dce0..727b179 100644 --- a/tests/json/cg1.json +++ b/tests/json/cg1.json @@ -73,7 +73,7 @@ "#": "optimiser: decrease learning rate. 1000-0.2 means, that every 1000 step the lr is x 0.2", "#": "comment the following line to disable scheduler", - "schedule_learning_rate_step": 500, + "schedule_learning_rate_step": 10, "schedule_learning_rate_gamma": 0.9, "#": "optimiser: number of D and G update by epoch", @@ -81,7 +81,7 @@ "g_nb_update": 1, "#": "optimiser: max nb of epoch (iteration)", - "epoch": 10000, + "epoch": 100, "#": "optimiser: nb of samples by batch", "batch_size": 10000, diff --git a/tests/json/g1.json b/tests/json/g1.json index 2afdd85..96e4ef6 100644 --- a/tests/json/g1.json +++ b/tests/json/g1.json @@ -73,7 +73,7 @@ "#": "optimiser: decrease learning rate. 1000-0.2 means, that every 1000 step the lr is x 0.2", "#": "comment the following line to disable scheduler", - "schedule_learning_rate_step": 10000, + "schedule_learning_rate_step": 10, "schedule_learning_rate_gamma": 0.8, "#": "optimiser: number of D and G update by epoch", @@ -81,7 +81,7 @@ "g_nb_update": 1, "#": "optimiser: max nb of epoch (iteration)", - "epoch": 10000, + "epoch": 100, "#": "optimiser: nb of samples by batch", "batch_size": 10000, diff --git a/tests/readme_tests.md b/tests/readme_gaga_phsp_tests.md similarity index 93% rename from tests/readme_tests.md rename to tests/readme_gaga_phsp_tests.md index b3844b0..461ab85 100644 --- a/tests/readme_tests.md +++ b/tests/readme_gaga_phsp_tests.md @@ -1,4 +1,4 @@ -# Test 1 : Linac phsp +# (Test 1 : Linac phsp) ==> See GateBenchmark/t9_gaga_phsp @@ -20,9 +20,6 @@ train: gaga_train npy/gauss_v1.npy json/g1.json -f pth/ -pi epoch 1000 gaga_train npy/gauss_v2.npy json/g2.json -f pth/ -pi epoch 5000 -(linux is about 2 min for 1000) -(WARNING bug with mps on osx, much too slow) - result: gaga_gauss_plot npy/gauss_v1.npy pth/g1_GP_SquareHinge_1_1000.pth -n 1e4 @@ -56,8 +53,6 @@ train: gaga_train npy/xgauss_10_1e6.npy json/cg1.json -f pth -pi epoch 4000 -(OK with mps and with cuda) - result: # warning x and y not independent here! @@ -73,3 +68,4 @@ Convert root dataset to parametrisation, replace P_exit by P_ideal = P_exit - c gt_phsp_plot b.npy spect_training_dataset.root -n 1e4 +# Test 6 : TODO gaga_garf_generate_img \ No newline at end of file diff --git a/tests/test001_non_cond.py b/tests/test001_non_cond.py new file mode 100755 index 0000000..c052c7b --- /dev/null +++ b/tests/test001_non_cond.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import gaga_phsp as gaga +import gatetools.phsp as phsp + +if __name__ == "__main__": + """ + Training : about 20-30 sec (gpu) + """ + + # input + phsp_filename = f"npy/gauss_v1.npy" + pth_filename = f"pth/test001_non_cond.pth" + + # step 1 + cmd = f"gaga_gauss_test {phsp_filename} -n 8e5 -t v1" + gaga.run_and_check(cmd) + + # step 2 + cmd = f"gaga_train {phsp_filename} json/g1.json -o {pth_filename} -pi epoch 20" + gaga.run_and_check(cmd) + + # step 3 + cmd = f"gaga_gauss_plot {phsp_filename} {pth_filename} -n 1e4" + gaga.run_and_check(cmd) + print("Results in cond.png") + + plt = pth_filename.replace(".pth", ".png") + cmd = f"gaga_plot {phsp_filename} {pth_filename} -o {plt}" + gaga.run_and_check(cmd) + print(f"Results in {plt}") + + # load phsp + n = 1e5 + print(f"Load phsp : {phsp_filename}") + real, r_keys, m = phsp.load(phsp_filename, nmax=n) + print(f"real shape {real.shape} {r_keys}") + + # load gaga + params, G, D, optim = gaga.load(pth_filename) + print(f"Keys : {params['keys']}") + print(f"Keys_list : {params['keys_list']}") + + # generate (non cond) + batch_size = 1e5 + fake = gaga.generate_samples_non_cond( + params, G, n, batch_size, normalize=False, to_numpy=True + ) + print(f"fake shape {fake.shape}") + + # compare fake and real + print() + gaga.compare_sampled_points(r_keys, real, fake, wtol=0.3, tol=0.08) + + # end + gaga.test_ok(True) diff --git a/tests/test002_cond.py b/tests/test002_cond.py new file mode 100755 index 0000000..bc5c7ef --- /dev/null +++ b/tests/test002_cond.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import gaga_phsp as gaga +import gatetools.phsp as phsp + + +if __name__ == "__main__": + """ + Training : about 2 min (gpu) + """ + + # input + phsp_filename = f"npy/test002_cond.npy" + pth_filename = f"pth/test002_cond.pth" + + # step 1 + cmd = f"gaga_gauss_cond_test {phsp_filename} -n 4e4 -m 10" + gaga.run_and_check(cmd) + + # step 2 + cmd = f"gaga_train {phsp_filename} json/cg1.json -o {pth_filename} -pi epoch 30" + gaga.run_and_check(cmd) + + # step 3 + cmd = f"gaga_gauss_plot {phsp_filename} {pth_filename} -n 1e4" + gaga.run_and_check(cmd) + print("Results in cond.png") + + plt = pth_filename.replace(".pth", ".png") + cmd = f"gaga_plot {phsp_filename} {pth_filename} --cond_phsp {phsp_filename} -o {plt}" + gaga.run_and_check(cmd) + print(f"Results in {plt}") + + # load phsp + n = 1e5 + print(f"Load phsp : {phsp_filename}") + real, r_keys, m = phsp.load(phsp_filename, nmax=n, shuffle=True) + print(f"real shape {real.shape} {r_keys}") + cond = real[:, 2:4] + print(f"cond shape {cond.shape}") + + # load gaga + params, G, D, optim = gaga.load(pth_filename) + print(f"Keys : {params['keys']}") + print(f"Keys_list : {params['keys_list']}") + + # generate (non cond) + fake = gaga.generate_samples3(params, G, n, cond) + print(f"fake shape {fake.shape}") + + # compare fake and real + print() + gaga.compare_sampled_points(r_keys[0:2], real, fake, wtol=0.21, tol=0.03) + + # end + gaga.test_ok(True)