Skip to content

Commit

Permalink
feat(vhdl): remove translatable protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhoever committed Sep 7, 2022
1 parent 37412e8 commit ca52a92
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions elasticai/creator/examples/translate_lstm_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,24 @@ def main() -> None:
fixed_point_factory = FixedPoint.get_factory(total_bits=8, frac_bits=4)
work_library_name = "xil_defaultlib"
translation_args = dict(
LSTMTranslatable=LSTMTranslationArgs(
LSTMModule=LSTMTranslationArgs(
fixed_point_factory=fixed_point_factory,
sigmoid_resolution=(-2.5, 2.5, 256),
tanh_resolution=(-1, 1, 256),
work_library_name=work_library_name,
),
Linear1dTranslatable=Linear1dTranslationArgs(
Linear1dModule=Linear1dTranslationArgs(
fixed_point_factory=fixed_point_factory,
work_library_name=work_library_name,
),
)

translatable_layers = translator.translate_model(
vhdl_modules = translator.translate_model(
model=model, build_function_mapping=DEFAULT_BUILD_FUNCTION_MAPPING
)

code_repr = translator.generate_code(
vhdl_modules=translatable_layers, translation_args=translation_args
vhdl_modules=vhdl_modules, translation_args=translation_args
)

translator.save_code(code_repr=code_repr, path=args.build_dir)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def setUp(self) -> None:
self.lstm.bias_hh_l0 = arange_parameter(start=12, end=16, shape=(4,))

def test_build_lstm_layer_weights_correct_set(self) -> None:
lstm_translatable = build_lstm(self.lstm)
lstm_module = build_lstm(self.lstm)

self.assertEqual(lstm_translatable.weights_ih, [[[0.0], [1.0], [2.0], [3.0]]])
self.assertEqual(lstm_translatable.weights_hh, [[[4.0], [5.0], [6.0], [7.0]]])
self.assertEqual(lstm_module.weights_ih, [[[0.0], [1.0], [2.0], [3.0]]])
self.assertEqual(lstm_module.weights_hh, [[[4.0], [5.0], [6.0], [7.0]]])

self.assertEqual(lstm_translatable.biases_ih, [[8.0, 9.0, 10.0, 11.0]])
self.assertEqual(lstm_translatable.biases_hh, [[12.0, 13.0, 14.0, 15.0]])
self.assertEqual(lstm_module.biases_ih, [[8.0, 9.0, 10.0, 11.0]])
self.assertEqual(lstm_module.biases_hh, [[12.0, 13.0, 14.0, 15.0]])

0 comments on commit ca52a92

Please sign in to comment.