From 7f65b6e50a867b0077a79fb412ab909b3714cf98 Mon Sep 17 00:00:00 2001
From: Soran Ghaderi
Date: Fri, 2 Sep 2022 17:54:55 +0430
Subject: [PATCH] fix: Resolve incompatible arguments

All encoder and decoder blocks were given more arguments than required.
Modify the number of steps to plot.
---
 test_main.py      | 32 +++++++++++++++++---------------
 txplot/plot_pe.py |  2 +-
 2 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/test_main.py b/test_main.py
index 1b86965..a1b2163 100644
--- a/test_main.py
+++ b/test_main.py
@@ -1,4 +1,6 @@
-import pytest
+# import pytest
+import os
+
 from layers.transformer_decoder_block import TransformerDecoderBlock
 from layers.transformer_encoder import TransformerEncoder
 from layers.transformer_encoder_block import TransformerEncoderBlock
@@ -9,17 +11,17 @@ from layers.multihead_attention import MultiHeadAttention
 import numpy as np
 import tensorflow as tf
 
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
-
-@pytest.fixture()
-def test_transpose_qkv():
-    x = np.random.random([100, 10, 5])
-    assert MultiHeadAttention.split_heads(x, x)
+# @pytest.fixture()
+# def test_transpose_qkv():
+#     x = np.random.random([100, 10, 5])
+#     assert MultiHeadAttention.split_heads(x, x)
 
 
-encoding_dim, num_steps = 32, 60
-pos_encoding = PositionalEncoding(encoding_dim, 0)
-X = pos_encoding(tf.zeros((2, num_steps, encoding_dim)), training=False)
+depth, num_steps = 32, 50
+pos_encoding = PositionalEncoding(depth, 0)
+X = pos_encoding(tf.zeros((2, num_steps, depth)), training=False)
 P = pos_encoding.P[:, : X.shape[1], :]
 plotter = Plot()
 plotter.plot_pe(np.arange(7, 11), P, num_steps)
@@ -38,12 +40,12 @@ def test_transpose_qkv():
 X = tf.ones((2, 100, 24))
 valid_lens = tf.constant([3, 2])
 norm_shape = [i for i in range(len(X.shape))][1:]
-encoder_blk = TransformerEncoderBlock(24, 24, 24, 24, norm_shape, 48, 8, 0.5)
-print(encoder_blk(X, valid_lens, training=False))
+encoder_block = TransformerEncoderBlock(24, norm_shape, 48, 8, 0.5)
+print(encoder_block(X, valid_lens, training=False))
 
-encoder = TransformerEncoder(200, 24, 24, 24, 24, [1, 2], 48, 8, 2, 0.5)
+encoder = TransformerEncoder(200, 24, [1, 2], 48, 8, 2, 0.5)
 print(encoder(tf.ones((2, 100)), valid_lens, training=False).shape, (2, 100, 24))
 
-decoder_blk = TransformerDecoderBlock(24, 24, 24, 24, [1, 2], 48, 8, 0.5, 0)
-state = [encoder_blk(X, valid_lens), valid_lens, [None]]
-print(decoder_blk(X, state, training=False)[0].shape, X.shape)
+decoder_block = TransformerDecoderBlock(24, [1, 2], 48, 8, 0.5, 0)
+state = [encoder_block(X, valid_lens), valid_lens, [None]]
+print(decoder_block(X, state, training=False)[0].shape, X.shape)
diff --git a/txplot/plot_pe.py b/txplot/plot_pe.py
index 685d5a2..7148bc5 100644
--- a/txplot/plot_pe.py
+++ b/txplot/plot_pe.py
@@ -13,7 +13,7 @@ def plot_pe(
     num_steps,
     show_grid=True,
 ):
-    ax = plt.figure(figsize=(6, 2.5))
+    ax = plt.figure(figsize=(6, 2.5), dpi=1000)
     lines = ["-", "--", "-.", ":"]
     self.line_cycler = cycle(lines)
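
---

Reviewer note (not part of the patch): a minimal sketch of the corrected
constructor signatures exercised above, assuming this repo's layers package
is importable. The dimension values (24, 48, 8, 0.5) are the same toy values
used in test_main.py, not canonical defaults, and the parameter readings in
the comments are inferred from the diff rather than documented.

    import tensorflow as tf
    from layers.transformer_encoder_block import TransformerEncoderBlock
    from layers.transformer_decoder_block import TransformerDecoderBlock

    X = tf.ones((2, 100, 24))          # (batch, seq_len, d_model)
    valid_lens = tf.constant([3, 2])   # valid tokens per batch element
    norm_shape = [1, 2]                # axes to normalize, as in the test

    # A single model-dimension argument (24) replaces the four repeated
    # 24s that the old, over-long signature required.
    encoder_block = TransformerEncoderBlock(24, norm_shape, 48, 8, 0.5)
    enc_out = encoder_block(X, valid_lens, training=False)

    # The decoder block takes the same reduced signature plus a trailing
    # block index (0 here).
    decoder_block = TransformerDecoderBlock(24, norm_shape, 48, 8, 0.5, 0)
    state = [enc_out, valid_lens, [None]]
    dec_out = decoder_block(X, state, training=False)[0]
    print(dec_out.shape, X.shape)      # shapes should match, per the test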