Skip to content

Commit

Permalink
# modify generation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SHINDONGMYUNG committed May 14, 2021
1 parent acc3623 commit 7c54794
Showing 1 changed file with 33 additions and 15 deletions.
48 changes: 33 additions & 15 deletions envs/generation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
""" This code is a main script to run the RF generation module of DeepRF
to design a slice-selective excitation pulse"""

from os import path
import sys
sys.path.append(path.join(path.dirname(__file__), '..'))
Expand All @@ -16,6 +13,7 @@
import torch
import torch.nn as nn
import envs
import matplotlib.pyplot as plt
from utils.logger import Logger
from utils.summary import EvaluationMetrics
from scipy.io import loadmat
Expand All @@ -26,15 +24,15 @@
parser.add_argument('--gpu', type=int, default=0, help='activated GPU number')
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
parser.add_argument('--el', type=int, default=32, help='episode length')
parser.add_argument('--amp', type=float, default=1e-3, help='amplitude scaling')
parser.add_argument('--ph', type=float, default=1e+1, help='phase scaling')
parser.add_argument('--amp', type=float, default=1e-3, help='amplitude scaling') # 1e-3
parser.add_argument('--ph', type=float, default=1e+1, help='phase scaling') # 1e+1
parser.add_argument('--hss', type=int, default=256, help='length of hidden state in GRU')
parser.add_argument('--batch', type=int, default=256, help='batch size (# episodes)')
parser.add_argument('--mb', type=int, default=2048, help='mini batch size (for 1 epoch)')
parser.add_argument('--v_hs', type=str, default='(128,64,32)', help='network structure')
parser.add_argument('--v_hs', type=str, default='(256,128,64,32)', help='network structure')
parser.add_argument('--gamma', type=float, default=1.0, help='discount factor')
parser.add_argument('--lmbda', type=float, default=0.95, help='lambda for GAE')
parser.add_argument('--eps', type=float, default=0.1, help='epsilon for PPO')
parser.add_argument('--eps', type=float, default=0.1, help='args.eps for PPO')
parser.add_argument('--epochs', type=int, default=4, help='number of epochs for gradient-descent')
parser.add_argument('--max', type=int, default=300, help='maximum number of iterations')
parser.add_argument('--kl', type=float, default=0.01, help='target KL value for early stopping')
Expand All @@ -45,7 +43,7 @@
parser.add_argument('--seed', type=int, default=1003, help='random seed')
parser.add_argument('--grad', type=float, default=1000.0, help='l2-norm for gradient clipping')
parser.add_argument('--save', type=int, default=300, help='save period in iterations')
parser.add_argument("--tag", type=str, default='exc_generation')
parser.add_argument("--tag", type=str, default='ppo_rnn_exc_21')
parser.add_argument("--log_level", type=int, default=10)
parser.add_argument("--debug", "-d", action="store_true")
parser.add_argument("--quiet", "-q", action="store_true")
Expand All @@ -66,17 +64,18 @@
ts = float(args.du) * 1e-3 / (float(args.sampling_rate))
max_rad = 2 * np.pi * 42.577 * 1e+6 * 0.2 * 1e-4 * ts

# random seed control
# random seed
tf.reset_default_graph()
tf.random.set_random_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# load reference SLR excitation pulse
# load reference pulse
device = 'cuda' if torch.cuda.is_available() else 'cpu'
preset = loadmat('../data/conv_rf/SLR_exc.mat')
ref_pulse = torch.unsqueeze(torch.from_numpy(np.array(preset['result'], dtype=np.float32)), dim=0).to(device)


# %% function definitions
def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
for h in hidden_sizes[:-1]:
Expand All @@ -89,7 +88,7 @@ def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
state_in = tf.placeholder(tf.float32, [None, args.hss]) # hidden state in/out size
adv = tf.placeholder(tf.float32, [None, ]) # advantages
prob = tf.placeholder(tf.float32, [None, ]) # history of probabilities
ret_in = tf.placeholder(tf.float32, [None, ]) # returns
ret_in = tf.placeholder(tf.float32, [None, ]) # returns for training

cell = tf.contrib.rnn.GRUCell(args.hss, reuse=tf.AUTO_REUSE)
state_out, _ = tf.nn.dynamic_rnn(cell, rnn_in, initial_state=state_in, dtype=tf.float32, time_major=True)
Expand Down Expand Up @@ -118,7 +117,7 @@ def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
tf.scalar_mul(1 / np.sqrt(np.pi * 2.0), tf.divide(tf.ones(tf.shape(std)),
std))), axis=1)
p_pi = p_pi_a + EPS
ratio = tf.exp(tf.log(p_pi) - tf.log(prob))
ratio = tf.exp(tf.log(p_pi) - tf.log(prob)) # ratio of pi_theta and pi_theta_old

approx_kl = -tf.reduce_mean(tf.log(p_pi) - tf.log(prob))
surr1 = tf.multiply(ratio, adv)
Expand Down Expand Up @@ -203,12 +202,13 @@ def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
input_list_tmp[:, :, 0] = input_list_tmp[:, :, 0] * args.amp
input_list_tmp[:, :, 1] = input_list_tmp[:, :, 1] * args.ph

input_list_tmp[:, :, 0] = 2 * (input_list_tmp[:, :, 0] / max_rad) - 1
input_list_tmp[:, :, 0] = 2 * (input_list_tmp[:, :, 0] / max_rad) - 1 # -1 ~ 1
# input_list_tmp[:, :, 1] = (input_list_tmp[:, :, 1] % (2 * np.pi) - np.pi) / np.pi # -1 ~ 1
input_list_tmp[:, :, 1] = (input_list_tmp[:, :, 1] - input_list_tmp[:, 0, 1, np.newaxis]) / np.pi

rf_list = np.vstack((rf_list, input_list_tmp))

with torch.no_grad():
with torch.no_grad(): # for inference only

m = nn.Parameter(torch.from_numpy(input_list_tmp[:, :, 0]).to(device))
p = nn.Parameter(torch.from_numpy(input_list_tmp[:, :, 1]).to(device))
Expand All @@ -223,7 +223,7 @@ def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
torch.cat((b1[:, 1, :], torch.fliplr(b1[:, 1, :])), dim=1)], dim=1)
b1 = torch.cat((ref_pulse, b1), dim=0)

# Virtual MRI simulation
# Simulation
t = 0
done = False
total_rewards = 0.0
Expand Down Expand Up @@ -261,23 +261,41 @@ def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):

# Log summary statistics
if (it + 1) % args.log_step == 0:
# Inversion profile
profile = plt.figure(1)
plt.plot(np.concatenate((env.df[200:200 + 800], env.df[:200], env.df[200 + 800:200 + 1600])),
np.concatenate((mz[idx, 200:200 + 800], mz[idx, :200], mz[idx, 200 + 800:200 + 1600])))
# logger.image_summary(profile, it + 1, 'profile')

# RF pulse magnitude
t = np.linspace(0, env.du / len(env), len(env))
magnitude = b1[:, 0].detach().cpu().numpy()
fig_m = plt.figure(2)
plt.plot(t, magnitude[idx])
plt.ylim(-1, 1)
# logger.image_summary(fig_m, it + 1, 'magnitude')

# RF pulse phase
phase = b1[:, 1].detach().cpu().numpy()
fig_p = plt.figure(3)
plt.plot(t, phase[idx])
plt.ylim(-1, 1)
# logger.image_summary(fig_p, it + 1, 'phase')

if (it + 1) % args.save == 0:
array_dict = {'magnitude': magnitude, 'phase': phase, 'sar': sar, 'rf_list': rf_list[1:, ...],
'mz1': ripple1, 'mz2': ripple2, 'rew': rew, 'rew_list': rew_list}
# else:
# array_dict = {'magnitude': magnitude, 'phase': phase, 'sar': sar,
# 'mz1': ripple1, 'mz2': ripple2, 'rew': rew}
logger.savemat('pulse' + str(it + 1), array_dict)

logger.scalar_summary(info.val, it + 1)

info.reset()

# GAE
# val_list's size: (args.batch, args.el)
target = np.roll(val_list, -1, axis=1) * args.gamma
target[:, -1] = rew
delta = target - val_list
Expand Down

0 comments on commit 7c54794

Please sign in to comment.