Skip to content

Commit

Permalink
improve visualizations
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Jun 17, 2024
1 parent ae58487 commit bb988a6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 24 deletions.
2 changes: 2 additions & 0 deletions tsgm/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def extract_archive(from_path: str, to_path: T.Optional[str] = None, pwd: T.Opti
def download(url: str, path: str, md5: T.Optional[str] = None, max_attempt: int = 3) -> None:
logger.info(f"### Downloading from {url} ###")
os.makedirs(path, exist_ok=True)
if "?" in url:
url, _ = url.split("?")
resource_name = url.split("/")[-1]
path = os.path.join(path, resource_name)
for attempt in range(max_attempt):
Expand Down
30 changes: 23 additions & 7 deletions tsgm/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def visualize_ts_lineplot(
:param unite_features: Whether to plot all features together or separately, defaults to True.
:type unite_features: bool, optional
:param legend_fontsize: Font size to use.
:type unite_features: int, optional
:type legend_fontsize: int, optional
:param tick_size: Font size for y-axis ticks.
:type tick_size: int, optional
"""
Expand All @@ -269,24 +269,40 @@ def visualize_ts_lineplot(
else:
for feat_id in range(ts.shape[2]):
sns.lineplot(
x=range(ts.shape[1]), y=ts[sample_id, :, feat_id], ax=axs[i]
x=range(ts.shape[1]), y=ts[sample_id, :, feat_id], ax=axs[i],
label="Generated"
)
if ys is not None:
axs[i].tick_params(labelsize=tick_size)
axs[i].tick_params(labelsize=tick_size, which="both")
if len(ys.shape) == 1:
axs[i].set_title(ys[sample_id], fontsize=legend_fontsize)
elif len(ys.shape) == 2:
ax2 = axs[i].twinx()
sns.lineplot(
x=range(ts.shape[1]),
y=ys[sample_id],
ax=axs[i].twinx(),
ax=ax2,
color="g",
label="Target variable",
label="Condition",
)
axs[i].twinx().tick_params(labelsize=tick_size)
# axs[i].twinx().yaxis.set_ticks_position('right')
ax2.tick_params(labelsize=tick_size)
if i == 0:
leg = ax2.legend(fontsize=legend_fontsize, loc='upper right')
for legobj in leg.legendHandles:
legobj.set_linewidth(2.0)
else:
ax2.get_legend().remove()
else:
raise ValueError("ys contains too many dimensions")
axs[i].legend(fontsize=legend_fontsize)
if i == 0:
leg = axs[i].legend(fontsize=legend_fontsize, loc='upper left')
for legobj in leg.legendHandles:
legobj.set_linewidth(2.0)
else:
axs[i].get_legend().remove()
if i != len(ids) - 1:
axs[i].set_xticks([])


def visualize_original_and_reconst_ts(
Expand Down
25 changes: 8 additions & 17 deletions tutorials/Model Selection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "60611bed",
"metadata": {},
"source": [
"## Model selection.\n",
"# Model Selection in TSGM\n",
"This is a simple example of model selection through hyperparameter optimization."
]
},
Expand Down Expand Up @@ -36,7 +36,7 @@
"id": "0290b339",
"metadata": {},
"source": [
"#### 0. Install optuna\n",
"## 0. Install optuna\n",
"\n",
"Let's first install Optuna."
]
Expand All @@ -49,16 +49,7 @@
"outputs": [],
"source": [
"import sys\n",
"!{sys.executable} -m pip install optuna"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01e0ec58",
"metadata": {},
"outputs": [],
"source": [
"!{sys.executable} -m pip install optuna\n",
"import optuna"
]
},
Expand All @@ -67,7 +58,7 @@
"id": "a375b04e",
"metadata": {},
"source": [
"#### 1. Load data\n",
"## 1. Load data\n",
"We are using a small dataset generated by the `tsgm.utils.gen_sine_dataset` function. We scale the features using `tsgm.utils.TSFeatureWiseScaler` to ensure each feature falls within the range of $[0, 1]$."
]
},
Expand All @@ -90,7 +81,7 @@
"id": "b76dd508",
"metadata": {},
"source": [
"#### 2. Define the optimization problem"
"## 2. Define the optimization problem"
]
},
{
Expand Down Expand Up @@ -127,8 +118,8 @@
"id": "5bbab28c",
"metadata": {},
"source": [
"#### 3. Define the search space for the optimizer\n",
"We can optimize the choice of the optimizer and its hyperparameters"
"## 3. Define the search space for the optimizer\n",
"We can optimize the choice of the optimizer and hyperparameters."
]
},
{
Expand Down Expand Up @@ -172,7 +163,7 @@
"id": "70786d6c",
"metadata": {},
"source": [
"#### 4. Define the objective function\n",
"## 4. Define the objective function\n",
"In the objective function, we load the data and use them to train a TimeGAN model (`tsgm.models.timeGAN.TimeGAN`) while changing its parameters. We use the fitted TimeGAN model to generate synthetic samples, and finally use them to compute the value of the metric we want to optimize. "
]
},
Expand Down

0 comments on commit bb988a6

Please sign in to comment.