Skip to content

Commit

Permalink
Merge pull request #14 from tensorops/lib_structuring_patch
Browse files Browse the repository at this point in the history
Resolve the issue mentioned in #11
  • Loading branch information
soran-ghaderi authored Sep 3, 2022
2 parents 51db09f + 17fa614 commit 40de79e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 11 additions & 9 deletions test_main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np
import pytest
import tensorflow as tf
Expand All @@ -18,9 +20,9 @@ def test_transpose_qkv():
assert MultiHeadAttention.split_heads(x, x)


encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(tf.zeros((2, num_steps, encoding_dim)), training=False)
depth, num_steps = 32, 50
pos_encoding = PositionalEncoding(depth, 0)
X = pos_encoding(tf.zeros((2, num_steps, depth)), training=False)
P = pos_encoding.P[:, : X.shape[1], :]
plotter = Plot()
plotter.plot_pe(np.arange(7, 11), P, num_steps)
Expand All @@ -39,12 +41,12 @@ def test_transpose_qkv():
X = tf.ones((2, 100, 24))
valid_lens = tf.constant([3, 2])
norm_shape = [i for i in range(len(X.shape))][1:]
encoder_blk = TransformerEncoderBlock(24, 24, 24, 24, norm_shape, 48, 8, 0.5)
print(encoder_blk(X, valid_lens, training=False))
encoder_block = TransformerEncoderBlock(24, norm_shape, 48, 8, 0.5)
print(encoder_block(X, valid_lens, training=False))

encoder = TransformerEncoder(200, 24, 24, 24, 24, [1, 2], 48, 8, 2, 0.5)
encoder = TransformerEncoder(200, 24, [1, 2], 48, 8, 2, 0.5)
print(encoder(tf.ones((2, 100)), valid_lens, training=False).shape, (2, 100, 24))

decoder_blk = TransformerDecoderBlock(24, 24, 24, 24, [1, 2], 48, 8, 0.5, 0)
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
print(decoder_blk(X, state, training=False)[0].shape, X.shape)
decoder_block = TransformerDecoderBlock(24, [1, 2], 48, 8, 0.5, 0)
state = [encoder_block(X, valid_lens), valid_lens, [None]]
print(decoder_block(X, state, training=False)[0].shape, X.shape)
2 changes: 1 addition & 1 deletion txplot/plot_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def plot_pe(
num_steps,
show_grid=True,
):
ax = plt.figure(figsize=(6, 2.5))
ax = plt.figure(figsize=(6, 2.5), dpi=1000)

lines = ["-", "--", "-.", ":"]
self.line_cycler = cycle(lines)
Expand Down

0 comments on commit 40de79e

Please sign in to comment.